diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 73df2d9ed9..191c161613 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -9,8 +9,8 @@ from itertools import starmap from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from services.dataset_service import DatasetCollectionBindingService @@ -28,6 +28,7 @@ class DatasetCollectionBindingTestDataFactory: @staticmethod def create_collection_binding( + db_session_with_containers: Session, provider_name: str = "openai", model_name: str = "text-embedding-ada-002", collection_name: str = "collection-abc", @@ -51,8 +52,8 @@ class DatasetCollectionBindingTestDataFactory: collection_name=collection_name, type=collection_type, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -64,7 +65,7 @@ class TestDatasetCollectionBindingServiceGetBinding: including various provider/model combinations, collection types, and edge cases. """ - def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session): """ Test successful retrieval of an existing collection binding. @@ -77,6 +78,7 @@ class TestDatasetCollectionBindingServiceGetBinding: model_name = "text-embedding-ada-002" collection_type = "dataset" existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name=provider_name, model_name=model_name, collection_name="existing-collection", @@ -92,7 +94,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.id == existing_binding.id assert result.collection_name == "existing-collection" - def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session): """ Test successful creation of a new collection binding when none exists. @@ -116,7 +118,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.type == collection_type assert result.collection_name is not None - def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with different collection type.""" # Arrange provider_name = "openai" @@ -133,7 +135,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with default collection type parameter.""" # Arrange provider_name = "openai" @@ -147,7 +149,9 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_provider_model_combination( + self, db_session_with_containers: Session + ): """Test get_dataset_collection_binding with various provider/model combinations.""" # Arrange combinations = [ @@ -174,10 +178,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: including successful retrieval and error handling for missing bindings. """ - def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session): """Test successful retrieval of collection binding by ID and type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -194,7 +199,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.collection_name == "test-collection" assert result.type == "dataset" - def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session): """Test error handling when collection binding is not found by ID and type.""" # Arrange non_existent_id = str(uuid4()) @@ -203,10 +208,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") - def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID and type with different collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -222,10 +230,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == binding.id assert result.type == "custom_type" - def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_default_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID with default collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -239,10 +250,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == binding.id assert result.type == "dataset" - def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session): """Test error when binding exists but with wrong collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 9871ef37e6..4b98bddd26 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -10,9 +10,9 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum from models.model import App @@ -27,6 +27,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -37,13 +38,13 @@ class DatasetUpdateDeleteTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -51,14 +52,15 @@ class DatasetUpdateDeleteTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -78,12 +80,12 @@ class DatasetUpdateDeleteTestDataFactory: retrieval_model={"top_k": 2}, enable_api=enable_api, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App: + def create_app(db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test App") -> App: """Create a real app for AppDatasetJoin.""" app = App( tenant_id=tenant_id, @@ -96,16 +98,16 @@ class DatasetUpdateDeleteTestDataFactory: enable_api=True, created_by=created_by, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app @staticmethod - def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin: """Create a real AppDatasetJoin record.""" join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @@ -114,7 +116,7 @@ class TestDatasetServiceDeleteDataset: Comprehensive integration tests for DatasetService.delete_dataset method. """ - def test_delete_dataset_success(self, db_session_with_containers): + def test_delete_dataset_success(self, db_session_with_containers: Session): """ Test successful deletion of a dataset. @@ -130,8 +132,10 @@ class TestDatasetServiceDeleteDataset: - Method returns True """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted: @@ -139,10 +143,10 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None mock_dataset_was_deleted.send.assert_called_once_with(dataset) - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -156,7 +160,9 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) dataset_id = str(uuid4()) # Act @@ -165,7 +171,7 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is False - def test_delete_dataset_permission_denied_error(self, db_session_with_containers): + def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session): """ Test error handling when user lacks permission. @@ -178,19 +184,22 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL, tenant=tenant, ) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act & Assert with pytest.raises(NoPermissionError): DatasetService.delete_dataset(dataset.id, normal_user) # Verify no deletion was attempted - assert db.session.get(Dataset, dataset.id) is not None + assert db_session_with_containers.get(Dataset, dataset.id) is not None class TestDatasetServiceDatasetUseCheck: @@ -198,7 +207,7 @@ class TestDatasetServiceDatasetUseCheck: Comprehensive integration tests for DatasetService.dataset_use_check method. """ - def test_dataset_use_check_in_use(self, db_session_with_containers): + def test_dataset_use_check_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is in use. @@ -211,10 +220,12 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) - app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id) - DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id) + DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -222,7 +233,7 @@ class TestDatasetServiceDatasetUseCheck: # Assert assert result is True - def test_dataset_use_check_not_in_use(self, db_session_with_containers): + def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is not in use. @@ -235,8 +246,10 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -250,7 +263,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: Comprehensive integration tests for DatasetService.update_dataset_api_status method. """ - def test_update_dataset_api_status_enable_success(self, db_session_with_containers): + def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session): """ Test successful enabling of dataset API access. @@ -264,8 +277,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -276,12 +293,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is True assert dataset.updated_by == owner.id assert dataset.updated_at == current_time - def test_update_dataset_api_status_disable_success(self, db_session_with_containers): + def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session): """ Test successful disabling of dataset API access. @@ -295,8 +312,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=True + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -307,11 +328,11 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, False) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False assert dataset.updated_by == owner.id - def test_update_dataset_api_status_not_found_error(self, db_session_with_containers): + def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session): """ Test error handling when dataset is not found. @@ -330,7 +351,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(dataset_id, True) - def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers): + def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session): """ Test error handling when current_user is missing. @@ -343,8 +364,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - No updates are committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) # Act & Assert with ( @@ -354,6 +379,6 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Verify no commit was attempted - db.session.rollback() - db.session.refresh(dataset) + db_session_with_containers.rollback() + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 606e7e0b57..8595f5bf14 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -45,7 +46,7 @@ class TestAccountService: "passport_service": mock_passport_service, } - def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_login(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation and login with correct password. """ @@ -70,7 +71,9 @@ class TestAccountService: logged_in = AccountService.authenticate(email, password) assert logged_in.id == account.id - def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_without_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation without password (for OAuth users). """ @@ -92,7 +95,7 @@ class TestAccountService: assert account.password_salt is None def test_create_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account create with invalid new password format. @@ -113,7 +116,9 @@ class TestAccountService: password="invalid_new_password", ) - def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_registration_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation when registration is disabled. """ @@ -131,7 +136,9 @@ class TestAccountService: password=fake.password(length=12), ) - def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation when email is in freeze period. """ @@ -154,7 +161,9 @@ class TestAccountService: dify_config.BILLING_ENABLED = False # Reset config for other tests - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent account. """ @@ -164,7 +173,7 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) - def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. """ @@ -186,14 +195,13 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(AccountLoginError): AccountService.authenticate(email, password) - def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_wrong_password(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with wrong password. """ @@ -217,7 +225,9 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, wrong_password) - def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_with_invite_token( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invite token to set password for account without password. """ @@ -249,7 +259,7 @@ class TestAccountService: assert authenticated_account.password_salt is not None def test_authenticate_pending_account_activation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication activates pending account. @@ -270,16 +280,17 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Authenticate should activate the account authenticated_account = AccountService.authenticate(email, password) assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None - def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_password_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful password update. """ @@ -308,7 +319,7 @@ class TestAccountService: assert authenticated_account.id == account.id def test_update_account_password_wrong_current_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with wrong current password. @@ -335,7 +346,7 @@ class TestAccountService: AccountService.update_account_password(account, wrong_password, new_password) def test_update_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with invalid new password format. @@ -360,7 +371,7 @@ class TestAccountService: with pytest.raises(ValueError): # Password validation error AccountService.update_account_password(account, old_password, "123") - def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation with automatic tenant creation. """ @@ -387,14 +398,13 @@ class TestAccountService: assert account.email == email # Verify tenant was created and linked - from extensions.ext_database import db - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" def test_create_account_and_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace creation is disabled. @@ -419,7 +429,7 @@ class TestAccountService: ) def test_create_account_and_tenant_workspace_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace limit is exceeded. @@ -446,7 +456,9 @@ class TestAccountService: password=password, ) - def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + def test_link_account_integrate_new_provider( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test linking account with new OAuth provider. """ @@ -469,15 +481,18 @@ class TestAccountService: AccountService.link_account_integrate("new-google", "google_open_id_123", account) # Verify integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="new-google") + .first() + ) assert integration is not None assert integration.open_id == "google_open_id_123" def test_link_account_integrate_existing_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test linking account with existing provider (should update). @@ -504,15 +519,16 @@ class TestAccountService: AccountService.link_account_integrate("exists-google", "google_open_id_456", account) # Verify integration was updated - from extensions.ext_database import db from models import AccountIntegrate integration = ( - db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="exists-google") + .first() ) assert integration.open_id == "google_open_id_456" - def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_close_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test closing an account. """ @@ -536,12 +552,11 @@ class TestAccountService: AccountService.close_account(account) # Verify account status changed - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.CLOSED - def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating account fields. """ @@ -568,7 +583,9 @@ class TestAccountService: assert updated_account.name == updated_name assert updated_account.interface_theme == "dark" - def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_invalid_field( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating account with invalid field. """ @@ -591,7 +608,7 @@ class TestAccountService: with pytest.raises(AttributeError): AccountService.update_account(account, invalid_field="value") - def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_login_info(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating login information. """ @@ -616,13 +633,12 @@ class TestAccountService: AccountService.update_login_info(account, ip_address=ip_address) # Verify login info was updated - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.last_login_ip == ip_address assert account.last_login_at is not None - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login with token generation. """ @@ -659,7 +675,9 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_pending_account_activation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test login activates pending account. """ @@ -680,17 +698,16 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Login should activate the account token_pair = AccountService.login(account) - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.ACTIVE - def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + def test_logout(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test logout functionality. """ @@ -723,7 +740,7 @@ class TestAccountService: refresh_token_key = f"account_refresh_token:{account.id}" assert redis_client.get(refresh_token_key) is None - def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful token refresh. """ @@ -757,7 +774,7 @@ class TestAccountService: assert new_token_pair.access_token == "new_mock_access_token" assert new_token_pair.refresh_token != initial_token_pair.refresh_token - def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test refresh token with invalid token. """ @@ -766,7 +783,9 @@ class TestAccountService: with pytest.raises(ValueError, match="Invalid refresh token"): AccountService.refresh_token(invalid_token) - def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test refresh token with valid token but invalid account. """ @@ -791,16 +810,15 @@ class TestAccountService: token_pair = AccountService.login(account) # Delete account - from extensions.ext_database import db - db.session.delete(account) - db.session.commit() + db_session_with_containers.delete(account) + db_session_with_containers.commit() # Try to refresh token with deleted account with pytest.raises(ValueError, match="Invalid account"): AccountService.refresh_token(token_pair.refresh_token) - def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading user by ID successfully. """ @@ -830,7 +848,7 @@ class TestAccountService: assert loaded_user.id == account.id assert loaded_user.email == account.email - def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading non-existent user. """ @@ -839,7 +857,7 @@ class TestAccountService: loaded_user = AccountService.load_user(non_existent_user_id) assert loaded_user is None - def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading banned user raises Unauthorized. """ @@ -861,14 +879,13 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.load_user(account.id) - def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_account_jwt_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test JWT token generation for account. """ @@ -902,7 +919,7 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_logged_in_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading logged in account by ID. """ @@ -931,7 +948,9 @@ class TestAccountService: assert loaded_account is not None assert loaded_account.id == account.id - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email successfully. """ @@ -957,7 +976,9 @@ class TestAccountService: assert found_user is not None assert found_user.id == account.id - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through non-existent email. """ @@ -968,7 +989,7 @@ class TestAccountService: assert found_user is None def test_get_user_through_email_banned_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting banned user through email raises Unauthorized. @@ -991,14 +1012,15 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.get_user_through_email(email) - def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email that is in freeze period. """ @@ -1014,7 +1036,7 @@ class TestAccountService: # Reset config dify_config.BILLING_ENABLED = False - def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account deletion (should add task to queue and sync to enterprise). """ @@ -1050,7 +1072,7 @@ class TestAccountService: mock_delete_task.delay.assert_called_once_with(account.id) def test_generate_account_deletion_verification_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generating account deletion verification code. @@ -1079,7 +1101,9 @@ class TestAccountService: assert len(code) == 6 assert code.isdigit() - def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_valid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying valid account deletion code. """ @@ -1106,7 +1130,9 @@ class TestAccountService: is_valid = AccountService.verify_account_deletion_code(token, code) assert is_valid is True - def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying invalid account deletion code. """ @@ -1135,7 +1161,7 @@ class TestAccountService: assert is_valid is False def test_verify_account_deletion_code_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test verifying account deletion code with invalid token. @@ -1167,7 +1193,7 @@ class TestTenantService: "billing_service": mock_billing_service, } - def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant creation with default settings. """ @@ -1187,7 +1213,7 @@ class TestTenantService: assert tenant.encrypt_public_key is not None def test_create_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant creation when workspace creation is disabled. @@ -1202,7 +1228,9 @@ class TestTenantService: with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception TenantService.create_tenant(name=tenant_name) - def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_with_custom_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant creation with custom name and setup flag. """ @@ -1221,7 +1249,9 @@ class TestTenantService: assert tenant.status == "normal" assert tenant.encrypt_public_key is not None - def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful tenant member creation. """ @@ -1251,7 +1281,9 @@ class TestTenantService: assert tenant_member.account_id == account.id assert tenant_member.role == "admin" - def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_duplicate_owner( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test creating duplicate owner for a tenant (should fail). """ @@ -1290,7 +1322,9 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant already has an owner"): TenantService.create_tenant_member(tenant, account2, role="owner") - def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating role for existing tenant member. """ @@ -1323,7 +1357,7 @@ class TestTenantService: assert tenant_member2.account_id == tenant_member1.account_id assert tenant_member2.role == "editor" - def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_join_tenants_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting join tenants for an account. """ @@ -1361,7 +1395,7 @@ class TestTenantService: assert tenant2_name in tenant_names def test_get_current_tenant_by_account_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant by account successfully. @@ -1388,9 +1422,8 @@ class TestTenantService: # Add account to tenant and set as current TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get current tenant current_tenant = TenantService.get_current_tenant_by_account(account) @@ -1400,7 +1433,7 @@ class TestTenantService: assert current_tenant.role == "owner" def test_get_current_tenant_by_account_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant when account has no current tenant. @@ -1426,7 +1459,7 @@ class TestTenantService: with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) - def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant switching. """ @@ -1457,18 +1490,17 @@ class TestTenantService: # Set initial current tenant account.current_tenant = tenant1 - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Switch to second tenant TenantService.switch_tenant(account, tenant2.id) # Verify tenant was switched - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.current_tenant_id == tenant2.id - def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tenant switching without providing tenant ID. """ @@ -1493,7 +1525,9 @@ class TestTenantService: with pytest.raises(ValueError, match="Tenant ID must be provided"): TenantService.switch_tenant(account, None) - def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_account_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test switching to a tenant where account is not a member. """ @@ -1520,7 +1554,7 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): TenantService.switch_tenant(account, tenant.id) - def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking if tenant has specific roles. """ @@ -1570,7 +1604,7 @@ class TestTenantService: has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) assert has_normal is False - def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_invalid_role_type(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking roles with invalid role type. """ @@ -1589,7 +1623,7 @@ class TestTenantService: with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): TenantService.has_roles(tenant, [invalid_role]) - def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting user role in a tenant. """ @@ -1620,7 +1654,9 @@ class TestTenantService: assert user_role == "editor" - def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission successfully. """ @@ -1660,7 +1696,7 @@ class TestTenantService: TenantService.check_member_permission(tenant, owner_account, member_account, "add") def test_check_member_permission_invalid_action( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test checking member permission with invalid action. @@ -1692,7 +1728,9 @@ class TestTenantService: with pytest.raises(Exception, match="Invalid action"): TenantService.check_member_permission(tenant, account, None, invalid_action) - def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_operate_self( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission when trying to operate self. """ @@ -1722,7 +1760,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.check_member_permission(tenant, account, account, "remove") - def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful member removal from tenant (should sync to enterprise). """ @@ -1770,16 +1810,17 @@ class TestTenantService: ) # Verify member was removed - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join is None def test_remove_member_from_tenant_operate_self( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test removing member when trying to operate self. @@ -1810,7 +1851,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.remove_member_from_tenant(tenant, account, account) - def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test removing member who is not in the tenant. """ @@ -1849,7 +1892,7 @@ class TestTenantService: with pytest.raises(Exception, match="Member not in tenant"): TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) - def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful member role update. """ @@ -1889,15 +1932,16 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "admin", owner_account) # Verify role was updated - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join.role == "admin" - def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_to_owner(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating member role to owner (should change current owner to admin). """ @@ -1937,19 +1981,24 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "owner", owner_account) # Verify roles were updated correctly - from extensions.ext_database import db from models.account import TenantAccountJoin owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=owner_account.id) + .first() ) member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert owner_join.role == "admin" assert member_join.role == "owner" - def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_already_assigned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating member role to already assigned role. """ @@ -1989,7 +2038,7 @@ class TestTenantService: with pytest.raises(Exception, match="The provided role is already assigned to the member"): TenantService.update_member_role(tenant, member_account, "admin", owner_account) - def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_count_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant count successfully. """ @@ -2014,7 +2063,7 @@ class TestTenantService: assert tenant_count >= 3 def test_create_owner_tenant_if_not_exist_new_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant for new user without existing tenants. @@ -2044,17 +2093,16 @@ class TestTenantService: TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == workspace_name def test_create_owner_tenant_if_not_exist_existing_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when user already has a tenant. @@ -2083,20 +2131,19 @@ class TestTenantService: existing_tenant = TenantService.create_tenant(name=existing_tenant_name) TenantService.create_tenant_member(existing_tenant, account, role="owner") account.current_tenant = existing_tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Try to create owner tenant again (should not create new one) TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) # Verify no new tenant was created - tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + tenant_joins = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).all() assert len(tenant_joins) == 1 assert account.current_tenant.id == existing_tenant.id def test_create_owner_tenant_if_not_exist_workspace_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when workspace creation is disabled. @@ -2123,7 +2170,7 @@ class TestTenantService: with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) - def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_members_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant members successfully. """ @@ -2187,7 +2234,9 @@ class TestTenantService: elif member.email == normal_email: assert member.role == "normal" - def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_operator_members_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting dataset operator members successfully. """ @@ -2240,7 +2289,7 @@ class TestTenantService: assert dataset_operators[0].email == operator_email assert dataset_operators[0].role == "dataset_operator" - def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_custom_config_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting custom config successfully. """ @@ -2259,9 +2308,8 @@ class TestTenantService: # Set custom config custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} tenant.custom_config_dict = custom_config - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get custom config retrieved_config = TenantService.get_custom_config(tenant.id) @@ -2296,7 +2344,7 @@ class TestRegisterService: "passport_service": mock_passport_service, } - def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system setup with account creation and tenant setup. """ @@ -2309,11 +2357,10 @@ class TestRegisterService: mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False - from extensions.ext_database import db from models.model import DifySetup - db.session.query(DifySetup).delete() - db.session.commit() + db_session_with_containers.query(DifySetup).delete() + db_session_with_containers.commit() # Execute setup RegisterService.setup( @@ -2327,7 +2374,7 @@ class TestRegisterService: # Verify account was created from models import Account - account = db.session.query(Account).filter_by(email=admin_email).first() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() assert account is not None assert account.name == admin_name assert account.last_login_ip == ip_address @@ -2335,17 +2382,17 @@ class TestRegisterService: assert account.status == "active" # Verify DifySetup was created - dify_setup = db.session.query(DifySetup).first() + dify_setup = db_session_with_containers.query(DifySetup).first() assert dify_setup is not None # Verify tenant was created and linked from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_failure_rollback(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test setup failure with proper rollback of all created entities. """ @@ -2373,21 +2420,20 @@ class TestRegisterService: ) # Verify no entities were created (rollback worked) - from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin from models.model import DifySetup - account = db.session.query(Account).filter_by(email=admin_email).first() - tenant_count = db.session.query(Tenant).count() - tenant_join_count = db.session.query(TenantAccountJoin).count() - dify_setup_count = db.session.query(DifySetup).count() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() + tenant_count = db_session_with_containers.query(Tenant).count() + tenant_join_count = db_session_with_containers.query(TenantAccountJoin).count() + dify_setup_count = db_session_with_containers.query(DifySetup).count() assert account is None assert tenant_count == 0 assert tenant_join_count == 0 assert dify_setup_count == 0 - def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful account registration with workspace creation. """ @@ -2421,16 +2467,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == f"{name}'s Workspace" - def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_oauth(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration with OAuth integration. """ @@ -2467,14 +2512,19 @@ class TestRegisterService: assert account.initialized_at is not None # Verify OAuth integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider=provider) + .first() + ) assert integration is not None assert integration.open_id == open_id - def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_pending_status( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration with pending status. """ @@ -2511,14 +2561,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_creation_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace creation is disabled. """ @@ -2549,13 +2600,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_limit_exceeded( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace limit is exceeded. """ @@ -2589,13 +2641,12 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_without_workspace(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration without workspace creation. """ @@ -2624,13 +2675,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify no tenant was created - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_new_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a new member who doesn't have an account yet. """ @@ -2682,22 +2734,25 @@ class TestRegisterService: mock_send_mail.delay.assert_called_once() # Verify new account was created with pending status - from extensions.ext_database import db from models import Account, TenantAccountJoin - new_account = db.session.query(Account).filter_by(email=new_member_email).first() + new_account = db_session_with_containers.query(Account).filter_by(email=new_member_email).first() assert new_account is not None assert new_account.name == new_member_email.split("@")[0] # Default name from email assert new_account.status == "pending" # Verify tenant member was created tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=new_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "normal" - def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting an existing member who is not in the tenant yet. """ @@ -2749,16 +2804,19 @@ class TestRegisterService: mock_send_mail.delay.assert_not_called() # Verify tenant member was created for existing account - from extensions.ext_database import db from models.account import TenantAccountJoin tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=existing_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "admin" - def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member who is already in the tenant with pending status. """ @@ -2793,9 +2851,8 @@ class TestRegisterService: password=existing_pending_member_password, ) existing_account.status = "pending" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2820,7 +2877,9 @@ class TestRegisterService: # Verify email task was called mock_send_mail.delay.assert_called_once() - def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_no_inviter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member without providing an inviter. """ @@ -2846,7 +2905,7 @@ class TestRegisterService: ) def test_invite_new_member_account_already_in_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test inviting a member who is already in the tenant with active status. @@ -2882,9 +2941,8 @@ class TestRegisterService: password=already_in_tenant_password, ) existing_account.status = "active" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2899,7 +2957,9 @@ class TestRegisterService: inviter=inviter, ) - def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_invite_token_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation of invite token. """ @@ -2943,7 +3003,7 @@ class TestRegisterService: assert invitation_data["email"] == account.email assert invitation_data["workspace_id"] == tenant.id - def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_valid(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation of valid invite token. """ @@ -2974,7 +3034,9 @@ class TestRegisterService: # Verify token is valid assert is_valid is True - def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation of invalid invite token. """ @@ -2987,7 +3049,7 @@ class TestRegisterService: assert is_valid is False def test_revoke_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token with workspace ID and email. @@ -3030,7 +3092,7 @@ class TestRegisterService: assert redis_client.get(token_key) is not None def test_revoke_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token without workspace ID and email. @@ -3073,7 +3135,7 @@ class TestRegisterService: assert redis_client.get(token_key) is None def test_get_invitation_if_token_valid_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with valid token. @@ -3122,7 +3184,7 @@ class TestRegisterService: assert result["data"]["workspace_id"] == tenant.id def test_get_invitation_if_token_valid_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid token. @@ -3142,7 +3204,7 @@ class TestRegisterService: assert result is None def test_get_invitation_if_token_valid_invalid_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid tenant. @@ -3192,7 +3254,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_account_mismatch( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with account ID mismatch. @@ -3242,7 +3304,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_tenant_not_normal( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with tenant not in normal status. @@ -3269,9 +3331,8 @@ class TestRegisterService: # Change tenant status to non-normal tenant.status = "suspended" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Create a real token from extensions.ext_redis import redis_client @@ -3300,7 +3361,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_by_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token with workspace ID and email. @@ -3339,7 +3400,7 @@ class TestRegisterService: redis_client.delete(cache_key) def test_get_invitation_by_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token without workspace ID and email. @@ -3372,7 +3433,7 @@ class TestRegisterService: # Clean up redis_client.delete(token_key) - def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_invitation_token_key(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting invitation token key. """ diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index 00bce32f48..45839fd463 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account @@ -87,7 +88,7 @@ class TestAgentService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -133,13 +134,12 @@ class TestAgentService: # Update the app model config to set agent_mode for agent-chat mode if app.mode == "agent-chat" and app.app_model_config: app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []}) - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() return app, account - def _create_test_conversation_and_message(self, db_session_with_containers, app, account): + def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account): """ Helper method to create a test conversation and message with agent thoughts. @@ -153,8 +153,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation conversation = Conversation( id=fake.uuid4(), @@ -167,8 +165,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -180,12 +178,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -206,12 +204,12 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return conversation, message - def _create_test_agent_thoughts(self, db_session_with_containers, message): + def _create_test_agent_thoughts(self, db_session_with_containers: Session, message): """ Helper method to create test agent thoughts for a message. @@ -224,8 +222,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - agent_thoughts = [] # Create first agent thought @@ -251,7 +247,7 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought1) + db_session_with_containers.add(thought1) agent_thoughts.append(thought1) # Create second agent thought @@ -277,14 +273,14 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought2) + db_session_with_containers.add(thought2) agent_thoughts.append(thought2) - db.session.commit() + db_session_with_containers.commit() return agent_thoughts - def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of agent logs with complete data. """ @@ -344,7 +340,7 @@ class TestAgentService: assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon def test_get_agent_logs_conversation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when conversation is not found. @@ -358,7 +354,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Conversation not found"): AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4()) - def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_message_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when message is not found. """ @@ -372,7 +370,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Message not found"): AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4()) - def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when conversation is from end user. """ @@ -381,8 +381,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create end user end_user = EndUser( id=fake.uuid4(), @@ -393,8 +391,8 @@ class TestAgentService: session_id=fake.uuid4(), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create conversation with end user conversation = Conversation( @@ -408,8 +406,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -421,12 +419,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -447,8 +445,8 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -457,7 +455,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == end_user.name - def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_unknown_executor( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when executor is unknown. """ @@ -466,8 +466,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create conversation with non-existent account conversation = Conversation( id=fake.uuid4(), @@ -480,8 +478,8 @@ class TestAgentService: mode="chat", from_source="api", ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -493,12 +491,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -519,8 +517,8 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -529,7 +527,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == "Unknown" - def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_tool_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with tool errors. """ @@ -539,8 +539,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with tool error thought_with_error = MessageAgentThought( message_id=message.id, @@ -564,8 +562,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_error) - db.session.commit() + db_session_with_containers.add(thought_with_error) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -580,7 +578,7 @@ class TestAgentService: assert tool_call["error"] == "Tool execution failed" def test_get_agent_logs_without_agent_thoughts( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval when message has no agent thoughts. @@ -600,7 +598,7 @@ class TestAgentService: assert len(result["iterations"]) == 0 def test_get_agent_logs_app_model_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is not found. @@ -610,11 +608,9 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Remove app model config to test error handling app.app_model_config_id = None - db.session.commit() + db_session_with_containers.commit() # Create conversation without app model config conversation = Conversation( @@ -629,8 +625,8 @@ class TestAgentService: from_source="api", app_model_config_id=None, # Explicitly set to None ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -651,15 +647,15 @@ class TestAgentService: currency="USD", from_source="api", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test with pytest.raises(ValueError, match="App model config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) def test_get_agent_logs_agent_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when agent config is not found. @@ -677,7 +673,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Agent config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) - def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_agent_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful listing of agent providers. """ @@ -698,7 +696,7 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id)) - def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of specific agent provider. """ @@ -720,7 +718,9 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name) - def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_plugin_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when plugin daemon client raises an error. """ @@ -741,7 +741,7 @@ class TestAgentService: AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) def test_get_agent_logs_with_complex_tool_data( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with complex tool data and multiple tools. @@ -752,8 +752,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with multiple tools complex_thought = MessageAgentThought( message_id=message.id, @@ -799,8 +797,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(complex_thought) - db.session.commit() + db_session_with_containers.add(complex_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -831,7 +829,7 @@ class TestAgentService: assert tool_calls[2]["status"] == "success" assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon - def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test agent logs retrieval with message files and agent thought files. """ @@ -842,7 +840,6 @@ class TestAgentService: conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) from dify_graph.file import FileTransferMethod, FileType - from extensions.ext_database import db from models.enums import CreatorUserRole # Add files to message @@ -867,9 +864,9 @@ class TestAgentService: created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) - db.session.add(message_file1) - db.session.add(message_file2) - db.session.commit() + db_session_with_containers.add(message_file1) + db_session_with_containers.add(message_file2) + db_session_with_containers.commit() # Create agent thought with files thought_with_files = MessageAgentThought( @@ -895,8 +892,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_files) - db.session.commit() + db_session_with_containers.add(thought_with_files) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -912,7 +909,7 @@ class TestAgentService: assert "file2" in iterations[0]["files"] def test_get_agent_logs_with_different_timezone( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with different timezone settings. @@ -938,7 +935,9 @@ class TestAgentService: assert "T" in start_time # ISO format assert "+08:00" in start_time or "Z" in start_time # Timezone offset - def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_empty_tool_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with empty tool data. """ @@ -948,8 +947,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with empty tool data empty_thought = MessageAgentThought( message_id=message.id, @@ -964,8 +961,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(empty_thought) - db.session.commit() + db_session_with_containers.add(empty_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -979,7 +976,9 @@ class TestAgentService: tool_calls = iterations[0]["tool_calls"] assert len(tool_calls) == 0 # No tools to process - def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_malformed_json( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with malformed JSON data in tool fields. """ @@ -989,8 +988,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with malformed JSON malformed_thought = MessageAgentThought( message_id=message.id, @@ -1005,8 +1002,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(malformed_thought) - db.session.commit() + db_session_with_containers.add(malformed_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 4f5190e533..004d643955 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account @@ -52,7 +53,7 @@ class TestAnnotationService: "current_user": mock_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -115,11 +116,10 @@ class TestAnnotationService: tenant_id, ) - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -141,17 +141,16 @@ class TestAnnotationService: from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -180,12 +179,12 @@ class TestAnnotationService: from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_insert_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation. @@ -211,9 +210,8 @@ class TestAnnotationService: assert annotation.id is not None # Verify annotation was saved to database - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.id is not None # Verify add_annotation_to_index_task was called (when annotation setting exists) @@ -221,7 +219,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_insert_app_annotation_directly_requires_question( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Question must be provided when inserting annotations directly. @@ -238,7 +236,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) def test_insert_app_annotation_directly_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test direct insertion of app annotation when app is not found. @@ -260,7 +258,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) def test_update_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation. @@ -298,7 +296,7 @@ class TestAnnotationService: mock_external_service_dependencies["update_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_new( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating new annotation from message. @@ -307,8 +305,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { @@ -333,7 +331,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_update( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test updating existing annotation from message. @@ -342,8 +340,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial annotation initial_args = { @@ -373,7 +371,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message when app is not found. @@ -395,7 +393,7 @@ class TestAnnotationService: AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) def test_get_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of annotation list by app ID. @@ -428,7 +426,7 @@ class TestAnnotationService: assert annotation.account_id == account.id def test_get_annotation_list_by_app_id_with_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list with keyword search. @@ -462,7 +460,7 @@ class TestAnnotationService: assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. @@ -534,7 +532,7 @@ class TestAnnotationService: assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) def test_get_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list when app is not found. @@ -549,7 +547,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="App not found"): AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") - def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of app annotation. """ @@ -568,16 +568,19 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - from extensions.ext_database import db - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called (when annotation setting exists) # Note: In this test, no annotation setting exists, so task should not be called mock_external_service_dependencies["delete_task"].delay.assert_not_called() - def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deletion of app annotation when app is not found. """ @@ -593,7 +596,7 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) def test_delete_app_annotation_annotation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test deletion of app annotation when annotation is not found. @@ -606,7 +609,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="Annotation not found"): AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) - def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of app annotation. """ @@ -632,7 +637,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["enable_task"].delay.assert_called_once() - def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of app annotation. """ @@ -651,7 +658,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["disable_task"].delay.assert_called_once() - def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_cached_job( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test enabling app annotation when job is already cached. """ @@ -685,7 +694,9 @@ class TestAnnotationService: # Clean up redis_client.delete(enable_app_annotation_key) - def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_hit_histories_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation hit histories. """ @@ -728,7 +739,9 @@ class TestAnnotationService: assert history.app_id == app.id assert history.account_id == account.id - def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_annotation_history_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful addition of annotation history. """ @@ -763,16 +776,15 @@ class TestAnnotationService: ) # Verify hit count was incremented - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.hit_count == initial_hit_count + 1 # Verify history was created from models.model import AppAnnotationHitHistory history = ( - db.session.query(AppAnnotationHitHistory) + db_session_with_containers.query(AppAnnotationHitHistory) .where( AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id ) @@ -786,7 +798,9 @@ class TestAnnotationService: assert history.score == score assert history.source == "console" - def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_by_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation by ID. """ @@ -811,7 +825,9 @@ class TestAnnotationService: assert retrieved_annotation.content == annotation_args["answer"] assert retrieved_annotation.account_id == account.id - def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_batch_import_app_annotations_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful batch import of app annotations. """ @@ -854,7 +870,7 @@ class TestAnnotationService: mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() def test_batch_import_app_annotations_empty_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import with empty CSV file. @@ -889,7 +905,7 @@ class TestAnnotationService: assert "empty" in result["error_msg"].lower() def test_batch_import_app_annotations_quota_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import when quota is exceeded. @@ -935,7 +951,7 @@ class TestAnnotationService: assert "limit" in result["error_msg"].lower() def test_get_app_annotation_setting_by_app_id_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting enabled app annotation setting by app ID. @@ -944,7 +960,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -956,8 +971,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -967,8 +982,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Get annotation setting result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -981,7 +996,7 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" def test_get_app_annotation_setting_by_app_id_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting disabled app annotation setting by app ID. @@ -996,7 +1011,7 @@ class TestAnnotationService: assert result["enabled"] is False def test_update_app_annotation_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of app annotation setting. @@ -1005,7 +1020,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1017,8 +1031,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1028,8 +1042,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Update annotation setting update_args = { @@ -1046,11 +1060,11 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" # Verify database was updated - db.session.refresh(annotation_setting) + db_session_with_containers.refresh(annotation_setting) assert annotation_setting.score_threshold == 0.9 def test_export_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful export of annotation list by app ID. @@ -1083,7 +1097,7 @@ class TestAnnotationService: assert annotation.created_at <= exported_annotations[i - 1].created_at def test_export_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test export of annotation list when app is not found. @@ -1099,7 +1113,7 @@ class TestAnnotationService: AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) def test_insert_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation with annotation setting enabled. @@ -1108,7 +1122,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1120,8 +1133,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1131,8 +1144,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Setup annotation data annotation_args = { @@ -1161,7 +1174,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_update_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation with annotation setting enabled. @@ -1170,7 +1183,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1182,8 +1194,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1193,8 +1205,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # First, create an annotation original_args = { @@ -1234,7 +1246,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_delete_app_annotation_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful deletion of app annotation with annotation setting enabled. @@ -1243,7 +1255,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1255,8 +1266,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1267,8 +1278,8 @@ class TestAnnotationService: updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create an annotation first annotation_args = { @@ -1285,7 +1296,9 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called @@ -1297,7 +1310,7 @@ class TestAnnotationService: assert call_args[3] == collection_binding.id # collection_binding_id def test_up_insert_app_annotation_from_message_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message with annotation setting enabled. @@ -1306,7 +1319,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1318,8 +1330,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1329,12 +1341,12 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 8c8be2e670..b8bf8543bc 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.api_based_extension import APIBasedExtension from services.account_service import AccountService, TenantService @@ -31,7 +32,7 @@ class TestAPIBasedExtensionService: "requestor_instance": mock_requestor_instance, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -61,7 +62,7 @@ class TestAPIBasedExtensionService: return account, tenant - def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful saving of API-based extension. """ @@ -90,15 +91,16 @@ class TestAPIBasedExtensionService: assert saved_extension.created_at is not None # Verify extension was saved to database - from extensions.ext_database import db - db.session.refresh(saved_extension) + db_session_with_containers.refresh(saved_extension) assert saved_extension.id is not None # Verify ping connection was called mock_external_service_dependencies["requestor_instance"].request.assert_called_once() - def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_validation_errors( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation errors when saving extension with invalid data. """ @@ -132,7 +134,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all extensions by tenant ID. """ @@ -169,7 +173,7 @@ class TestAPIBasedExtensionService: # Verify descending order (newer first) assert extension.created_at <= extension_list[i - 1].created_at - def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of extension by tenant ID and extension ID. """ @@ -200,7 +204,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted assert retrieved_extension.created_at is not None - def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when extension is not found. """ @@ -214,7 +220,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) - def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of extension. """ @@ -238,12 +244,15 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.delete(created_extension) # Verify extension was deleted - from extensions.ext_database import db - deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + deleted_extension = ( + db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + ) assert deleted_extension is None - def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when saving extension with duplicate name. """ @@ -272,7 +281,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="name must be unique, it is already existed"): APIBasedExtensionService.save(extension_data2) - def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of existing extension. """ @@ -329,7 +340,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved - def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_connection_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test connection error when saving extension with invalid endpoint. """ @@ -356,7 +369,7 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.save(extension_data) def test_save_extension_invalid_api_key_length( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation error when saving extension with API key that is too short. @@ -378,7 +391,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must be at least 5 characters"): APIBasedExtensionService.save(extension_data) - def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation errors when saving extension with empty required fields. """ @@ -412,7 +425,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extensions when no extensions exist for tenant. """ @@ -428,7 +443,9 @@ class TestAPIBasedExtensionService: assert len(extension_list) == 0 assert extension_list == [] - def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_invalid_ping_response( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is invalid. """ @@ -452,7 +469,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'result': 'invalid'}"): APIBasedExtensionService.save(extension_data) - def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_missing_ping_result( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is missing result field. """ @@ -476,7 +495,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'status': 'ok'}"): APIBasedExtensionService.save(extension_data) - def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_wrong_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when tenant ID doesn't match. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 8544d23cdf..787a99f3e8 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -3,6 +3,7 @@ from unittest.mock import ANY, MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.model import EndUser @@ -118,7 +119,9 @@ class TestAppGenerateService: "global_dify_config": mock_global_dify_config, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"): + def _create_test_app_and_account( + self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat" + ): """ Helper method to create a test app and account for testing. @@ -169,7 +172,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers, app): + def _create_test_workflow(self, db_session_with_containers: Session, app): """ Helper method to create a test workflow for testing. @@ -191,14 +194,14 @@ class TestAppGenerateService: status="published", ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_completion_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for completion mode app. """ @@ -226,7 +229,7 @@ class TestAppGenerateService: mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once() - def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful generation for chat mode app. """ @@ -250,7 +253,9 @@ class TestAppGenerateService: mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_agent_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for agent chat mode app. """ @@ -274,7 +279,9 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_advanced_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for advanced chat mode app. """ @@ -300,7 +307,9 @@ class TestAppGenerateService: "advanced_chat_generator" ].return_value.convert_to_event_stream.assert_called_once() - def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_workflow_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for workflow mode app. """ @@ -324,7 +333,9 @@ class TestAppGenerateService: mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() - def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_specific_workflow_id( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with a specific workflow ID. """ @@ -355,7 +366,9 @@ class TestAppGenerateService: "workflow_service" ].return_value.get_published_workflow_by_id.assert_called_once() - def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_debugger_invoke_from( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with debugger invoke from. """ @@ -378,7 +391,9 @@ class TestAppGenerateService: # Verify draft workflow was fetched for debugger mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once() - def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_non_streaming_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with non-streaming mode. """ @@ -401,7 +416,7 @@ class TestAppGenerateService: # Verify rate limit exit was called for non-streaming mode mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with EndUser instead of Account. """ @@ -421,10 +436,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} @@ -438,7 +451,7 @@ class TestAppGenerateService: assert result == ["test_response"] def test_generate_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with billing enabled and sandbox plan. @@ -466,7 +479,9 @@ class TestAppGenerateService: # Verify billing service was called to consume quota mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() - def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_invalid_app_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with invalid app mode. """ @@ -491,7 +506,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_with_workflow_id_format_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with invalid workflow ID format. @@ -518,7 +533,7 @@ class TestAppGenerateService: assert "Invalid workflow_id format" in str(exc_info.value) def test_generate_with_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not found. @@ -552,7 +567,7 @@ class TestAppGenerateService: assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value) def test_generate_with_workflow_not_initialized_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not initialized for debugger. @@ -578,7 +593,7 @@ class TestAppGenerateService: assert "Workflow not initialized" in str(exc_info.value) def test_generate_with_workflow_not_published_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not published for non-debugger. @@ -604,7 +619,7 @@ class TestAppGenerateService: assert "Workflow not published" in str(exc_info.value) def test_generate_single_iteration_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for advanced chat mode. @@ -631,7 +646,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for workflow mode. @@ -658,7 +673,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_invalid_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test single iteration generation with invalid app mode. @@ -681,7 +696,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_single_loop_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for advanced chat mode. @@ -708,7 +723,7 @@ class TestAppGenerateService: ].return_value.single_loop_generate.assert_called_once() def test_generate_single_loop_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for workflow mode. @@ -732,7 +747,9 @@ class TestAppGenerateService: # Verify workflow generator was called mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once() - def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_single_loop_invalid_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test single loop generation with invalid app mode. """ @@ -753,7 +770,9 @@ class TestAppGenerateService: # Verify error message assert "Invalid app mode" in str(exc_info.value) - def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_more_like_this_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful more like this generation. """ @@ -778,7 +797,7 @@ class TestAppGenerateService: ].return_value.generate_more_like_this.assert_called_once() def test_generate_more_like_this_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test more like this generation with EndUser. @@ -799,10 +818,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() message_id = fake.uuid4() @@ -815,7 +832,7 @@ class TestAppGenerateService: assert result == ["more_like_this_response"] def test_get_max_active_requests_with_app_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with app-specific limit. @@ -835,7 +852,7 @@ class TestAppGenerateService: assert result == 10 def test_get_max_active_requests_with_config_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with config limit being smaller. @@ -856,7 +873,7 @@ class TestAppGenerateService: assert result <= 100 def test_get_max_active_requests_with_zero_limits( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with zero limits (infinite). @@ -875,7 +892,9 @@ class TestAppGenerateService: # Verify the result (should return config limit when app limit is 0) assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS - def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_exception_cleanup( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that rate limit exit is called when an exception occurs. """ @@ -904,7 +923,9 @@ class TestAppGenerateService: # Verify rate limit exit was called for cleanup mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_agent_mode_detection( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with agent mode detection based on app configuration. """ @@ -932,7 +953,7 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_different_invoke_from_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with different invoke from values. @@ -962,7 +983,7 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with complex arguments including files and external trace ID. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 745d6c97b0..fc3b20aaae 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account @@ -44,7 +45,7 @@ class TestAppService: "account_feature_service": mock_account_feature_service, } - def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app creation with basic parameters. """ @@ -98,7 +99,9 @@ class TestAppService: assert app.is_public is False assert app.is_universal is False - def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_with_different_modes( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app creation with different app modes. """ @@ -141,7 +144,7 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.created_by == account.id - def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app retrieval. """ @@ -189,7 +192,7 @@ class TestAppService: assert retrieved_app.tenant_id == created_app.tenant_id assert retrieved_app.created_by == created_app.created_by - def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful paginated app list retrieval. """ @@ -243,7 +246,9 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.mode == "chat" - def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with various filters. """ @@ -316,7 +321,9 @@ class TestAppService: my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) assert len(my_apps.items) == 1 - def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_tag_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with tag filters. """ @@ -386,7 +393,7 @@ class TestAppService: # Should return None when no apps match tag filter assert paginated_apps is None - def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app update with all fields. """ @@ -455,7 +462,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. """ @@ -508,7 +515,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app icon update. """ @@ -565,7 +572,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app site status update. """ @@ -623,7 +632,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_api_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app API status update. """ @@ -681,7 +692,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_no_change( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app site status update when status doesn't change. """ @@ -732,7 +745,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app deletion. """ @@ -778,12 +791,13 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_with_related_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app deletion with related data cleanup. """ @@ -839,12 +853,11 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app metadata retrieval. """ @@ -883,7 +896,7 @@ class TestAppService: assert "tool_icons" in app_meta # Note: get_app_meta currently only returns tool_icons - def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app code retrieval by app ID. """ @@ -923,7 +936,7 @@ class TestAppService: assert app_code is not None assert len(app_code) > 0 - def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app ID retrieval by app code. """ @@ -963,10 +976,9 @@ class TestAppService: site.status = "normal" site.default_language = "en-US" site.customize_token_strategy = "uuid" - from extensions.ext_database import db - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Get app ID by code app_id = AppService.get_app_id_by_code(site.code) @@ -974,7 +986,7 @@ class TestAppService: # Verify app ID was retrieved correctly assert app_id == app.id - def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test app creation with invalid mode. """ @@ -1010,7 +1022,7 @@ class TestAppService: app_service.create_app(tenant.id, app_args, account) def test_get_apps_with_special_characters_in_name( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test app retrieval with special characters in name search to verify SQL injection prevention. diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index c3decbf39d..102c1a1eb5 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,10 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.retrieval.retrieval_methods import RetrievalMethod from dify_graph.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from services.dataset_service import DatasetService @@ -25,7 +25,9 @@ class DatasetServiceIntegrationDataFactory: """Factory for creating real database entities used by integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create an account and tenant, then bind the account as current tenant member.""" account = Account( email=f"{uuid4()}@example.com", @@ -34,8 +36,8 @@ class DatasetServiceIntegrationDataFactory: status="active", ) tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -43,8 +45,8 @@ class DatasetServiceIntegrationDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.flush() + db_session_with_containers.add(join) + db_session_with_containers.flush() # Keep tenant context on the in-memory user without opening a separate session. account.role = role @@ -53,6 +55,7 @@ class DatasetServiceIntegrationDataFactory: @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -82,12 +85,14 @@ class DatasetServiceIntegrationDataFactory: collection_binding_id=collection_binding_id, chunk_structure=chunk_structure, ) - db.session.add(dataset) - db.session.flush() + db_session_with_containers.add(dataset) + db_session_with_containers.flush() return dataset @staticmethod - def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document: + def create_document( + db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt" + ) -> Document: """Create a document row belonging to the given dataset.""" document = Document( tenant_id=dataset.tenant_id, @@ -102,8 +107,8 @@ class DatasetServiceIntegrationDataFactory: indexing_status="completed", doc_form="text_model", ) - db.session.add(document) - db.session.flush() + db_session_with_containers.add(document) + db_session_with_containers.flush() return document @staticmethod @@ -118,10 +123,10 @@ class DatasetServiceIntegrationDataFactory: class TestDatasetServiceCreateDataset: """Integration coverage for DatasetService.create_empty_dataset.""" - def test_create_internal_dataset_basic_success(self, db_session_with_containers): + def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session): """Create a basic internal dataset with minimal configuration.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -133,17 +138,17 @@ class TestDatasetServiceCreateDataset: ) # Assert - created_dataset = db.session.get(Dataset, result.id) + created_dataset = db_session_with_containers.get(Dataset, result.id) assert created_dataset is not None assert created_dataset.provider == "vendor" assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_dataset.embedding_model_provider is None assert created_dataset.embedding_model is None - def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session): """Create an internal dataset with economy indexing and no embedding model.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -155,15 +160,15 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "economy" assert result.embedding_model_provider is None assert result.embedding_model is None - def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session): """Create a high-quality dataset and persist embedding model settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act @@ -179,7 +184,7 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "high_quality" assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name @@ -188,11 +193,12 @@ class TestDatasetServiceCreateDataset: model_type=ModelType.TEXT_EMBEDDING, ) - def test_create_dataset_duplicate_name_error(self, db_session_with_containers): + def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Raise duplicate-name error when the same tenant already has the name.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Duplicate Dataset", @@ -209,10 +215,10 @@ class TestDatasetServiceCreateDataset: account=account, ) - def test_create_external_dataset_success(self, db_session_with_containers): + def test_create_external_dataset_success(self, db_session_with_containers: Session): """Create an external dataset and persist external knowledge binding.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) external_knowledge_id = "knowledge-123" @@ -231,16 +237,16 @@ class TestDatasetServiceCreateDataset: ) # Assert - binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() + binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() assert result.provider == "external" assert binding is not None assert binding.external_knowledge_id == external_knowledge_id assert binding.external_knowledge_api_id == external_knowledge_api_id - def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers): + def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session): """Create a high-quality dataset with retrieval/reranking settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() retrieval_model = RetrievalModel( search_method=RetrievalMethod.SEMANTIC_SEARCH, @@ -271,14 +277,16 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.retrieval_model == retrieval_model.model_dump() mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") - def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(self, db_session_with_containers): + def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( + self, db_session_with_containers: Session + ): """Create high-quality dataset with explicitly configured embedding model.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_provider = "openai" embedding_model_name = "text-embedding-3-small" embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model( @@ -303,7 +311,7 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "high_quality" assert result.embedding_model_provider == embedding_provider assert result.embedding_model == embedding_model_name @@ -315,10 +323,10 @@ class TestDatasetServiceCreateDataset: model=embedding_model_name, ) - def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers): + def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session): """Persist retrieval model settings when creating an internal dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) retrieval_model = RetrievalModel( search_method=RetrievalMethod.SEMANTIC_SEARCH, reranking_enable=False, @@ -338,13 +346,13 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.retrieval_model == retrieval_model.model_dump() - def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers): + def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session): """Persist canonical custom permission when creating an internal dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -357,13 +365,13 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.permission == DatasetPermissionEnum.ALL_TEAM - def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers): + def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): """Raise error when external API template does not exist.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) # Act / Assert @@ -381,10 +389,10 @@ class TestDatasetServiceCreateDataset: external_knowledge_id="knowledge-123", ) - def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): """Raise error when external knowledge id is missing for external dataset creation.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) # Act / Assert @@ -406,10 +414,10 @@ class TestDatasetServiceCreateDataset: class TestDatasetServiceCreateRagPipelineDataset: """Integration coverage for DatasetService.create_empty_rag_pipeline_dataset.""" - def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session): """Create rag-pipeline dataset and pipeline rows when a name is provided.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="RAG Pipeline Dataset", @@ -425,8 +433,8 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - created_dataset = db.session.get(Dataset, result.id) - created_pipeline = db.session.get(Pipeline, result.pipeline_id) + created_dataset = db_session_with_containers.get(Dataset, result.id) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) assert created_dataset is not None assert created_dataset.name == entity.name assert created_dataset.runtime_mode == "rag_pipeline" @@ -436,10 +444,10 @@ class TestDatasetServiceCreateRagPipelineDataset: assert created_pipeline.name == entity.name assert created_pipeline.created_by == account.id - def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session): """Create rag-pipeline dataset with generated incremental name when input name is empty.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) generated_name = "Untitled 1" icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( @@ -460,25 +468,26 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - db.session.refresh(result) - created_pipeline = db.session.get(Pipeline, result.pipeline_id) + db_session_with_containers.refresh(result) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) assert result.name == generated_name assert created_pipeline is not None assert created_pipeline.name == generated_name mock_generate_name.assert_called_once() - def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Raise duplicate-name error when rag-pipeline dataset name already exists.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) duplicate_name = "Duplicate RAG Dataset" DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=duplicate_name, indexing_technique=None, ) - db.session.commit() + db_session_with_containers.commit() icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name=duplicate_name, @@ -496,10 +505,10 @@ class TestDatasetServiceCreateRagPipelineDataset: tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity ) - def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session): """Persist canonical custom permission for rag-pipeline dataset creation.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="Custom Permission RAG Dataset", @@ -515,13 +524,13 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.permission == DatasetPermissionEnum.ALL_TEAM - def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session): """Persist icon metadata when creating rag-pipeline dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) icon_info = IconInfo( icon="📚", icon_background="#E8F5E9", @@ -542,23 +551,25 @@ class TestDatasetServiceCreateRagPipelineDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.icon_info == icon_info.model_dump() class TestDatasetServiceUpdateAndDeleteDataset: """Integration coverage for SQL-backed update and delete behavior.""" - def test_update_dataset_duplicate_name_error(self, db_session_with_containers): + def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Reject update when target name already exists within the same tenant.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) source_dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Source Dataset", ) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Existing Dataset", @@ -568,17 +579,20 @@ class TestDatasetServiceUpdateAndDeleteDataset: with pytest.raises(ValueError, match="Dataset name already exists"): DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset that already has documents.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", chunk_structure="text_model", ) - DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id) + DatasetServiceIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, created_by=account.id + ) # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: @@ -586,14 +600,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete a dataset that has no documents and no indexing technique.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -606,14 +621,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete dataset when indexing_technique is None but doc_form path still exists.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -626,17 +642,17 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) class TestDatasetServiceRetrievalConfiguration: """Integration coverage for retrieval configuration persistence.""" - def test_get_dataset_retrieval_configuration(self, db_session_with_containers): + def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Return retrieval configuration that is persisted in SQL.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) retrieval_model = { "search_method": "semantic_search", "top_k": 5, @@ -644,6 +660,7 @@ class TestDatasetServiceRetrievalConfiguration: "reranking_enable": True, } dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, retrieval_model=retrieval_model, @@ -658,11 +675,12 @@ class TestDatasetServiceRetrievalConfiguration: assert result.retrieval_model["search_method"] == "semantic_search" assert result.retrieval_model["top_k"] == 5 - def test_update_dataset_retrieval_configuration(self, db_session_with_containers): + def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Persist retrieval configuration updates through DatasetService.update_dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", @@ -684,6 +702,6 @@ class TestDatasetServiceRetrievalConfiguration: result = DatasetService.update_dataset(dataset.id, update_data, account) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index ffdb501474..322b67d373 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -11,8 +11,8 @@ from unittest.mock import call, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.dataset import Dataset, Document from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -32,6 +32,7 @@ class DocumentBatchUpdateIntegrationDataFactory: @staticmethod def create_dataset( + db_session_with_containers: Session, dataset_id: str | None = None, tenant_id: str | None = None, name: str = "Test Dataset", @@ -47,12 +48,13 @@ class DocumentBatchUpdateIntegrationDataFactory: if dataset_id: dataset.id = dataset_id - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod def create_document( + db_session_with_containers: Session, dataset: Dataset, document_id: str | None = None, name: str = "test_document.pdf", @@ -89,13 +91,14 @@ class DocumentBatchUpdateIntegrationDataFactory: for key, value in kwargs.items(): setattr(document, key, value) - db.session.add(document) + db_session_with_containers.add(document) if commit: - db.session.commit() + db_session_with_containers.commit() return document @staticmethod def create_multiple_documents( + db_session_with_containers: Session, dataset: Dataset, document_ids: list[str], enabled: bool = True, @@ -106,6 +109,7 @@ class DocumentBatchUpdateIntegrationDataFactory: documents: list[Document] = [] for index, doc_id in enumerate(document_ids, start=1): document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, document_id=doc_id, name=f"document_{doc_id}.pdf", @@ -116,7 +120,7 @@ class DocumentBatchUpdateIntegrationDataFactory: commit=False, ) documents.append(document) - db.session.commit() + db_session_with_containers.commit() return documents @staticmethod @@ -173,13 +177,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: assert document.archived_at is None assert document.archived_by is None - def test_batch_update_enable_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Enable disabled documents and trigger indexing side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document_ids = [str(uuid4()), str(uuid4())] disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, dataset=dataset, document_ids=document_ids, enabled=False, @@ -192,7 +197,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Assert for document in disabled_docs: - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_enabled(document, FIXED_TIME) expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] @@ -203,13 +208,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls) def test_batch_update_enable_already_enabled_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip enable operation for already-enabled documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True) + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True + ) # Act DocumentService.batch_update_document_status( @@ -220,18 +227,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is True patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() - def test_batch_update_disable_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Disable completed documents and trigger remove-index tasks.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document_ids = [str(uuid4()), str(uuid4())] enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, dataset=dataset, document_ids=document_ids, enabled=True, @@ -248,7 +256,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Assert for document in enabled_docs: - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_disabled(document, user.id, FIXED_TIME) expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] @@ -259,13 +267,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls) def test_batch_update_disable_already_disabled_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip disable operation for already-disabled documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False, indexing_status="completed", @@ -281,17 +290,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(disabled_doc) + db_session_with_containers.refresh(disabled_doc) assert disabled_doc.enabled is False patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called() - def test_batch_update_disable_non_completed_document_error(self, db_session_with_containers, patched_dependencies): + def test_batch_update_disable_non_completed_document_error( + self, db_session_with_containers: Session, patched_dependencies + ): """Raise error when disabling a non-completed document.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, indexing_status="indexing", @@ -307,13 +319,13 @@ class TestDatasetServiceBatchUpdateDocumentStatus: user=user, ) - def test_batch_update_archive_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Archive enabled documents and trigger remove-index task.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=False + db_session_with_containers, dataset=dataset, enabled=True, archived=False ) # Act @@ -325,21 +337,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_archived(document, user.id, FIXED_TIME) patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) patched_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_archive_already_archived_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip archive operation for already-archived documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=True + db_session_with_containers, dataset=dataset, enabled=True, archived=True ) # Act @@ -351,20 +363,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.archived is True patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called() def test_batch_update_archive_disabled_document_no_index_removal( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Archive disabled document without index-removal side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=False, archived=False + db_session_with_containers, dataset=dataset, enabled=False, archived=False ) # Act @@ -376,18 +388,18 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_archived(document, user.id, FIXED_TIME) patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["remove_task"].delay.assert_not_called() - def test_batch_update_unarchive_documents_success(self, db_session_with_containers, patched_dependencies): + def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies): """Unarchive enabled documents and trigger add-index task.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=True + db_session_with_containers, dataset=dataset, enabled=True, archived=True ) # Act @@ -399,7 +411,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_unarchived(document) assert document.updated_at == FIXED_TIME patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") @@ -407,14 +419,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["add_task"].delay.assert_called_once_with(document.id) def test_batch_update_unarchive_already_unarchived_document_skipped( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Skip unarchive operation for already-unarchived documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=True, archived=False + db_session_with_containers, dataset=dataset, enabled=True, archived=False ) # Act @@ -426,20 +438,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.archived is False patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() def test_batch_update_unarchive_disabled_document_no_index_addition( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Unarchive disabled document without index-add side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( - dataset=dataset, enabled=False, archived=True + db_session_with_containers, dataset=dataset, enabled=False, archived=True ) # Act @@ -451,20 +463,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_unarchived(document) assert document.updated_at == FIXED_TIME patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() def test_batch_update_document_indexing_error_redis_cache_hit( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Raise DocumentIndexingError when redis indicates active indexing.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, name="test_document.pdf", enabled=True, @@ -483,12 +496,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: assert "test_document.pdf" in str(exc_info.value) patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") - def test_batch_update_async_task_error_handling(self, db_session_with_containers, patched_dependencies): + def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies): """Persist DB update, then propagate async task error.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error") # Act / Assert @@ -500,14 +515,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus: user=user, ) - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_enabled(document, FIXED_TIME) patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) - def test_batch_update_empty_document_list(self, db_session_with_containers, patched_dependencies): + def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies): """Return early when document_ids is empty.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() # Act @@ -520,10 +535,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].get.assert_not_called() patched_dependencies["redis_client"].setex.assert_not_called() - def test_batch_update_document_not_found_skipped(self, db_session_with_containers, patched_dependencies): + def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies): """Skip IDs that do not map to existing dataset documents.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() missing_document_id = str(uuid4()) @@ -540,18 +555,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["redis_client"].setex.assert_not_called() patched_dependencies["add_task"].delay.assert_not_called() - def test_batch_update_mixed_document_states_and_actions(self, db_session_with_containers, patched_dependencies): + def test_batch_update_mixed_document_states_and_actions( + self, db_session_with_containers: Session, patched_dependencies + ): """Process only the applicable document in a mixed-state enable batch.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=2, ) archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True, @@ -568,9 +589,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(disabled_doc) - db.session.refresh(enabled_doc) - db.session.refresh(archived_doc) + db_session_with_containers.refresh(disabled_doc) + db_session_with_containers.refresh(enabled_doc) + db_session_with_containers.refresh(archived_doc) self._assert_document_enabled(disabled_doc, FIXED_TIME) assert enabled_doc.enabled is True assert archived_doc.enabled is True @@ -582,13 +603,16 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id) - def test_batch_update_large_document_list_performance(self, db_session_with_containers, patched_dependencies): + def test_batch_update_large_document_list_performance( + self, db_session_with_containers: Session, patched_dependencies + ): """Handle large document lists with consistent updates and side effects.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() document_ids = [str(uuid4()) for _ in range(100)] documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, dataset=dataset, document_ids=document_ids, enabled=False, @@ -604,7 +628,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus: # Assert for document in documents: - db.session.refresh(document) + db_session_with_containers.refresh(document) self._assert_document_enabled(document, FIXED_TIME) assert patched_dependencies["redis_client"].setex.call_count == len(document_ids) @@ -616,17 +640,26 @@ class TestDatasetServiceBatchUpdateDocumentStatus: patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) def test_batch_update_mixed_document_states_complex_scenario( - self, db_session_with_containers, patched_dependencies + self, db_session_with_containers: Session, patched_dependencies ): """Process a complex mixed-state batch and update only eligible records.""" # Arrange - dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset() + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) user = DocumentBatchUpdateIntegrationDataFactory.create_user() - doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False) - doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=2) - doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=3) - doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=4) + doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=2 + ) + doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=3 + ) + doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=4 + ) doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True, @@ -645,11 +678,11 @@ class TestDatasetServiceBatchUpdateDocumentStatus: ) # Assert - db.session.refresh(doc1) - db.session.refresh(doc2) - db.session.refresh(doc3) - db.session.refresh(doc4) - db.session.refresh(doc5) + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + db_session_with_containers.refresh(doc3) + db_session_with_containers.refresh(doc4) + db_session_with_containers.refresh(doc5) self._assert_document_enabled(doc1, FIXED_TIME) assert doc2.enabled is True assert doc3.enabled is True diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index 6effe795e2..e78894fcae 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -10,7 +10,8 @@ Tests the retrieval of document segments with pagination and filtering: from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment from services.dataset_service import SegmentService @@ -23,6 +24,7 @@ class SegmentServiceTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -33,13 +35,13 @@ class SegmentServiceTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -47,14 +49,14 @@ class SegmentServiceTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_dataset(tenant_id: str, created_by: str) -> Dataset: + def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset: """Create a real dataset.""" dataset = Dataset( tenant_id=tenant_id, @@ -67,12 +69,14 @@ class SegmentServiceTestDataFactory: provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document: + def create_document( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str + ) -> Document: """Create a real document.""" document = Document( tenant_id=tenant_id, @@ -84,12 +88,13 @@ class SegmentServiceTestDataFactory: created_from="api", created_by=created_by, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @staticmethod def create_segment( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, document_id: str, @@ -112,8 +117,8 @@ class SegmentServiceTestDataFactory: tokens=tokens, created_by=created_by, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment @@ -130,7 +135,7 @@ class TestSegmentServiceGetSegments: - Combined filters """ - def test_get_segments_basic_pagination(self, db_session_with_containers): + def test_get_segments_basic_pagination(self, db_session_with_containers: Session): """ Test basic pagination functionality. @@ -140,11 +145,14 @@ class TestSegmentServiceGetSegments: - Returns segments and total count """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) segment1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -153,6 +161,7 @@ class TestSegmentServiceGetSegments: content="First segment", ) segment2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -170,7 +179,7 @@ class TestSegmentServiceGetSegments: assert items[0].id == segment1.id assert items[1].id == segment2.id - def test_get_segments_with_status_filter(self, db_session_with_containers): + def test_get_segments_with_status_filter(self, db_session_with_containers: Session): """ Test filtering by status list. @@ -179,11 +188,14 @@ class TestSegmentServiceGetSegments: - Only segments with matching status are returned """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -192,6 +204,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -200,6 +213,7 @@ class TestSegmentServiceGetSegments: status="indexing", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -219,7 +233,7 @@ class TestSegmentServiceGetSegments: statuses = {item.status for item in items} assert statuses == {"completed", "indexing"} - def test_get_segments_with_empty_status_list(self, db_session_with_containers): + def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session): """ Test with empty status list. @@ -228,11 +242,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied to avoid WHERE false condition """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -241,6 +258,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -256,7 +274,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_with_keyword_search(self, db_session_with_containers): + def test_get_segments_with_keyword_search(self, db_session_with_containers: Session): """ Test keyword search functionality. @@ -265,11 +283,14 @@ class TestSegmentServiceGetSegments: - Search pattern includes wildcards (%keyword%) """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -278,6 +299,7 @@ class TestSegmentServiceGetSegments: content="This contains search term in the middle", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -294,7 +316,7 @@ class TestSegmentServiceGetSegments: assert total == 1 assert "search term" in items[0].content - def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers): + def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session): """ Test ordering by position and id. @@ -304,12 +326,15 @@ class TestSegmentServiceGetSegments: - This prevents duplicate data across pages when positions are not unique """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with different positions seg_pos2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -318,6 +343,7 @@ class TestSegmentServiceGetSegments: content="Position 2", ) seg_pos1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -326,6 +352,7 @@ class TestSegmentServiceGetSegments: content="Position 1", ) seg_pos3 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -344,7 +371,7 @@ class TestSegmentServiceGetSegments: assert items[1].id == seg_pos2.id assert items[2].id == seg_pos3.id - def test_get_segments_empty_results(self, db_session_with_containers): + def test_get_segments_empty_results(self, db_session_with_containers: Session): """ Test when no segments match the criteria. @@ -353,7 +380,7 @@ class TestSegmentServiceGetSegments: - Total count is 0 """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) non_existent_doc_id = str(uuid4()) # Act @@ -363,7 +390,7 @@ class TestSegmentServiceGetSegments: assert items == [] assert total == 0 - def test_get_segments_combined_filters(self, db_session_with_containers): + def test_get_segments_combined_filters(self, db_session_with_containers: Session): """ Test with multiple filters combined. @@ -372,12 +399,15 @@ class TestSegmentServiceGetSegments: - Status list and keyword search both applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with various statuses and content SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -387,6 +417,7 @@ class TestSegmentServiceGetSegments: content="This is important information", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -396,6 +427,7 @@ class TestSegmentServiceGetSegments: content="This is also important", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -421,7 +453,7 @@ class TestSegmentServiceGetSegments: assert items[0].status == "completed" assert "important" in items[0].content - def test_get_segments_with_none_status_list(self, db_session_with_containers): + def test_get_segments_with_none_status_list(self, db_session_with_containers: Session): """ Test with None status list. @@ -430,11 +462,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -443,6 +478,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -462,7 +498,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers): + def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session): """ Test that max_per_page is correctly set to 100. @@ -471,13 +507,16 @@ class TestSegmentServiceGetSegments: - This prevents excessive page sizes """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create 105 segments to exceed max_per_page of 100 for i in range(105): SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index f605a286ed..8bd994937a 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -13,7 +13,8 @@ This test suite covers: import json from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -31,7 +32,9 @@ class DatasetRetrievalTestDataFactory: """Factory class for creating database-backed test data for dataset retrieval integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL + ) -> tuple[Account, Tenant]: """Create an account and tenant with the specified role.""" account = Account( email=f"{uuid4()}@example.com", @@ -43,8 +46,8 @@ class DatasetRetrievalTestDataFactory: name=f"tenant-{uuid4()}", status="normal", ) - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -52,14 +55,16 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: + def create_account_in_tenant( + db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> Account: """Create an account and add it to an existing tenant.""" account = Account( email=f"{uuid4()}@example.com", @@ -67,8 +72,8 @@ class DatasetRetrievalTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -76,14 +81,15 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -101,12 +107,14 @@ class DatasetRetrievalTestDataFactory: provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission: + def create_dataset_permission( + db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str + ) -> DatasetPermission: """Create a dataset permission.""" permission = DatasetPermission( dataset_id=dataset_id, @@ -114,12 +122,14 @@ class DatasetRetrievalTestDataFactory: account_id=account_id, has_permission=True, ) - db.session.add(permission) - db.session.commit() + db_session_with_containers.add(permission) + db_session_with_containers.commit() return permission @staticmethod - def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule: + def create_process_rule( + db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict + ) -> DatasetProcessRule: """Create a dataset process rule.""" process_rule = DatasetProcessRule( dataset_id=dataset_id, @@ -127,12 +137,14 @@ class DatasetRetrievalTestDataFactory: mode=mode, rules=json.dumps(rules), ) - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule @staticmethod - def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery: + def create_dataset_query( + db_session_with_containers: Session, dataset_id: str, created_by: str, content: str + ) -> DatasetQuery: """Create a dataset query.""" dataset_query = DatasetQuery( dataset_id=dataset_id, @@ -142,23 +154,23 @@ class DatasetRetrievalTestDataFactory: created_by_role="account", created_by=created_by, ) - db.session.add(dataset_query) - db.session.commit() + db_session_with_containers.add(dataset_query) + db_session_with_containers.commit() return dataset_query @staticmethod - def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin: """Create an app-dataset join.""" join = AppDatasetJoin( app_id=str(uuid4()), dataset_id=dataset_id, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @staticmethod - def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag: + def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag: """Create a knowledge tag and bind it to the target dataset.""" tag = Tag( tenant_id=tenant_id, @@ -166,8 +178,8 @@ class DatasetRetrievalTestDataFactory: name=f"tag-{uuid4()}", created_by=created_by, ) - db.session.add(tag) - db.session.flush() + db_session_with_containers.add(tag) + db_session_with_containers.flush() binding = TagBinding( tenant_id=tenant_id, @@ -175,8 +187,8 @@ class DatasetRetrievalTestDataFactory: target_id=target_id, created_by=created_by, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return tag @@ -195,15 +207,16 @@ class TestDatasetServiceGetDatasets: # ==================== Basic Retrieval Tests ==================== - def test_get_datasets_basic_pagination(self, db_session_with_containers): + def test_get_datasets_basic_pagination(self, db_session_with_containers: Session): """Test basic pagination without user or filters.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 for i in range(5): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"Dataset {i}", @@ -217,21 +230,23 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 5 assert total == 5 - def test_get_datasets_with_search(self, db_session_with_containers): + def test_get_datasets_with_search(self, db_session_with_containers: Session): """Test get_datasets with search keyword.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 search = "test" DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Test Dataset", permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Another Dataset", @@ -245,26 +260,32 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_with_tag_filtering(self, db_session_with_containers): + def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session): """Test get_datasets with tag_ids filtering.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 dataset_1 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) dataset_2 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) - tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id) - tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id) + tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_1.id + ) + tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_2.id + ) tag_ids = [tag_1.id, tag_2.id] # Act @@ -274,16 +295,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 2 assert total == 2 - def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers): + def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session): """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 tag_ids = [] for i in range(3): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"dataset-{i}", @@ -300,19 +322,21 @@ class TestDatasetServiceGetDatasets: # ==================== Permission-Based Filtering Tests ==================== - def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers): + def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session): """Test that without user, only ALL_TEAM datasets are shown.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -325,15 +349,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_owner_with_include_all(self, db_session_with_containers): + def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session): """Test that OWNER with include_all=True sees all datasets.""" # Arrange - owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) for i, permission in enumerate( [DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM] ): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, name=f"dataset-{i}", @@ -353,12 +380,15 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 3 assert total == 3 - def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session): """Test that normal user sees ONLY_ME datasets they created.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -371,13 +401,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session): """Test that normal user sees ALL_TEAM datasets.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -390,18 +425,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session): """Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.PARTIAL_TEAM, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, user.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) @@ -410,20 +452,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ONLY_ME, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, operator.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) @@ -432,14 +479,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR without permissions returns empty result.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -456,11 +506,13 @@ class TestDatasetServiceGetDatasets: class TestDatasetServiceGetDataset: """Comprehensive integration tests for DatasetService.get_dataset method.""" - def test_get_dataset_success(self, db_session_with_containers): + def test_get_dataset_success(self, db_session_with_containers: Session): """Test successful retrieval of a single dataset.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_dataset(dataset.id) @@ -469,7 +521,7 @@ class TestDatasetServiceGetDataset: assert result is not None assert result.id == dataset.id - def test_get_dataset_not_found(self, db_session_with_containers): + def test_get_dataset_not_found(self, db_session_with_containers: Session): """Test retrieval when dataset doesn't exist.""" # Arrange dataset_id = str(uuid4()) @@ -484,12 +536,15 @@ class TestDatasetServiceGetDataset: class TestDatasetServiceGetDatasetsByIds: """Comprehensive integration tests for DatasetService.get_datasets_by_ids method.""" - def test_get_datasets_by_ids_success(self, db_session_with_containers): + def test_get_datasets_by_ids_success(self, db_session_with_containers: Session): """Test successful bulk retrieval of datasets by IDs.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) datasets = [ - DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3) + DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) + for _ in range(3) ] dataset_ids = [dataset.id for dataset in datasets] @@ -501,7 +556,7 @@ class TestDatasetServiceGetDatasetsByIds: assert total == 3 assert all(dataset.id in dataset_ids for dataset in result_datasets) - def test_get_datasets_by_ids_empty_list(self, db_session_with_containers): + def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with empty list returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -514,7 +569,7 @@ class TestDatasetServiceGetDatasetsByIds: assert datasets == [] assert total == 0 - def test_get_datasets_by_ids_none_list(self, db_session_with_containers): + def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with None returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -530,17 +585,20 @@ class TestDatasetServiceGetDatasetsByIds: class TestDatasetServiceGetProcessRules: """Comprehensive integration tests for DatasetService.get_process_rules method.""" - def test_get_process_rules_with_existing_rule(self, db_session_with_containers): + def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when rule exists.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) rules_data = { "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], "segmentation": {"delimiter": "\n", "max_tokens": 500}, } DatasetRetrievalTestDataFactory.create_process_rule( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, mode="custom", @@ -554,11 +612,13 @@ class TestDatasetServiceGetProcessRules: assert result["mode"] == "custom" assert result["rules"] == rules_data - def test_get_process_rules_without_existing_rule(self, db_session_with_containers): + def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when no rule exists (returns defaults).""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_process_rules(dataset.id) @@ -572,16 +632,19 @@ class TestDatasetServiceGetProcessRules: class TestDatasetServiceGetDatasetQueries: """Comprehensive integration tests for DatasetService.get_dataset_queries method.""" - def test_get_dataset_queries_success(self, db_session_with_containers): + def test_get_dataset_queries_success(self, db_session_with_containers: Session): """Test successful retrieval of dataset queries.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 for i in range(3): DatasetRetrievalTestDataFactory.create_dataset_query( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, content=f"query-{i}", @@ -595,11 +658,13 @@ class TestDatasetServiceGetDatasetQueries: assert total == 3 assert all(query.dataset_id == dataset.id for query in queries) - def test_get_dataset_queries_empty_result(self, db_session_with_containers): + def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session): """Test retrieval when no queries exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 @@ -614,14 +679,16 @@ class TestDatasetServiceGetDatasetQueries: class TestDatasetServiceGetRelatedApps: """Comprehensive integration tests for DatasetService.get_related_apps method.""" - def test_get_related_apps_success(self, db_session_with_containers): + def test_get_related_apps_success(self, db_session_with_containers: Session): """Test successful retrieval of related apps.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) for _ in range(2): - DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id) + DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id) # Act result = DatasetService.get_related_apps(dataset.id) @@ -630,11 +697,13 @@ class TestDatasetServiceGetRelatedApps: assert len(result) == 2 assert all(join.dataset_id == dataset.id for join in result) - def test_get_related_apps_empty_result(self, db_session_with_containers): + def test_get_related_apps_empty_result(self, db_session_with_containers: Session): """Test retrieval when no related apps exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_related_apps(dataset.id) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 7f9135bb81..ebaa3b4637 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,9 +2,9 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from dify_graph.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings from services.dataset_service import DatasetService @@ -15,7 +15,9 @@ class DatasetUpdateTestDataFactory: """Factory class for creating real test data for dataset update integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create a real account and tenant with the given role.""" account = Account( email=f"{uuid4()}@example.com", @@ -23,12 +25,12 @@ class DatasetUpdateTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant(name=f"tenant-{account.id}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -36,14 +38,15 @@ class DatasetUpdateTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, provider: str = "vendor", @@ -71,12 +74,13 @@ class DatasetUpdateTestDataFactory: embedding_model=embedding_model, collection_binding_id=collection_binding_id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod def create_external_binding( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str, @@ -93,8 +97,8 @@ class DatasetUpdateTestDataFactory: external_knowledge_id=external_knowledge_id, external_knowledge_api_id=external_knowledge_api_id, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -112,10 +116,11 @@ class TestDatasetServiceUpdateDataset: # ==================== External Dataset Tests ==================== - def test_update_external_dataset_success(self, db_session_with_containers): + def test_update_external_dataset_success(self, db_session_with_containers: Session): """Test successful update of external dataset.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -124,12 +129,13 @@ class TestDatasetServiceUpdateDataset: retrieval_model="old_model", ) binding = DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, ) binding_id = binding.id - db.session.expunge(binding) + db_session_with_containers.expunge(binding) update_data = { "name": "new_name", @@ -142,8 +148,8 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) - updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() + db_session_with_containers.refresh(dataset) + updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -153,15 +159,17 @@ class TestDatasetServiceUpdateDataset: assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] assert result.id == dataset.id - def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): """Test error when external knowledge id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -173,17 +181,19 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): """Test error when external knowledge api id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -195,12 +205,13 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge api id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers): + def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session): """Test error when external knowledge binding is not found.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -216,15 +227,16 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge binding not found" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() # ==================== Internal Dataset Basic Tests ==================== - def test_update_internal_dataset_basic_success(self, db_session_with_containers): + def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session): """Test successful update of internal dataset with basic fields.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -244,7 +256,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -254,11 +266,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.embedding_model == "text-embedding-ada-002" assert result.id == dataset.id - def test_update_internal_dataset_filter_none_values(self, db_session_with_containers): + def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session): """Test that None values are filtered out except for description field.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -278,7 +291,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description is None @@ -289,11 +302,12 @@ class TestDatasetServiceUpdateDataset: # ==================== Indexing Technique Switch Tests ==================== - def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to economy.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -312,7 +326,7 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) mock_task.delay.assert_called_once_with(dataset.id, "remove") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "economy" assert dataset.embedding_model is None assert dataset.embedding_model_provider is None @@ -320,10 +334,11 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to high_quality.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -366,7 +381,7 @@ class TestDatasetServiceUpdateDataset: mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") mock_task.delay.assert_called_once_with(dataset.id, "add") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "high_quality" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -380,9 +395,10 @@ class TestDatasetServiceUpdateDataset: self, db_session_with_containers ): """Test preserving embedding settings when indexing technique remains unchanged.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -399,7 +415,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.indexing_technique == "high_quality" @@ -409,11 +425,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session): """Test updating internal dataset with new embedding model.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -465,7 +482,7 @@ class TestDatasetServiceUpdateDataset: regenerate_vectors_only=True, ) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.embedding_model == "text-embedding-3-small" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -474,9 +491,9 @@ class TestDatasetServiceUpdateDataset: # ==================== Error Handling Tests ==================== - def test_update_dataset_not_found_error(self, db_session_with_containers): + def test_update_dataset_not_found_error(self, db_session_with_containers: Session): """Test error when dataset is not found.""" - user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) update_data = {"name": "new_name"} with pytest.raises(ValueError) as context: @@ -484,11 +501,16 @@ class TestDatasetServiceUpdateDataset: assert "Dataset not found" in str(context.value) - def test_update_dataset_permission_error(self, db_session_with_containers): + def test_update_dataset_permission_error(self, db_session_with_containers: Session): """Test error when user doesn't have permission.""" - owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, provider="vendor", @@ -500,10 +522,11 @@ class TestDatasetServiceUpdateDataset: with pytest.raises(NoPermissionError): DatasetService.update_dataset(dataset.id, update_data, outsider) - def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session): """Test error when embedding model is not available.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 93516a0030..6712fe8454 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -5,6 +5,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import Engine +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config @@ -19,7 +20,7 @@ class TestFileService: """Integration tests for FileService using testcontainers.""" @pytest.fixture - def engine(self, db_session_with_containers): + def engine(self, db_session_with_containers: Session): bind = db_session_with_containers.get_bind() assert isinstance(bind, Engine) return bind @@ -46,7 +47,7 @@ class TestFileService: "extract_processor": mock_extract_processor, } - def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account for testing. @@ -67,18 +68,16 @@ class TestFileService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -89,15 +88,15 @@ class TestFileService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test end user for testing. @@ -118,14 +117,14 @@ class TestFileService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account): + def _create_test_upload_file( + self, db_session_with_containers: Session, mock_external_service_dependencies, account + ): """ Helper method to create a test upload file for testing. @@ -155,15 +154,13 @@ class TestFileService: source_url="", ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file # Test upload_file method - def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful file upload with valid parameters. """ @@ -196,7 +193,9 @@ class TestFileService: assert upload_file.id is not None - def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_end_user( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with end user instead of account. """ @@ -219,7 +218,7 @@ class TestFileService: assert upload_file.created_by_role == CreatorUserRole.END_USER def test_upload_file_with_datasets_source( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with datasets source parameter. @@ -244,7 +243,7 @@ class TestFileService: assert upload_file.source_url == "https://example.com/source" def test_upload_file_invalid_filename_characters( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with invalid filename characters. @@ -265,7 +264,7 @@ class TestFileService: ) def test_upload_file_filename_too_long( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with filename that exceeds length limit. @@ -295,7 +294,7 @@ class TestFileService: assert len(base_name) <= 200 def test_upload_file_datasets_unsupported_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload for datasets with unsupported file type. @@ -316,7 +315,9 @@ class TestFileService: source="datasets", ) - def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_too_large( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with file size exceeding limit. """ @@ -338,7 +339,7 @@ class TestFileService: # Test is_file_size_within_limit method def test_is_file_size_within_limit_image_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files within limit. @@ -351,7 +352,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_video_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for video files within limit. @@ -364,7 +365,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_audio_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for audio files within limit. @@ -377,7 +378,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_document_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for document files within limit. @@ -390,7 +391,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_image_exceeded( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files exceeding limit. @@ -403,7 +404,7 @@ class TestFileService: assert result is False def test_is_file_size_within_limit_unknown_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for unknown file extension. @@ -416,7 +417,7 @@ class TestFileService: assert result is True # Test upload_text method - def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful text upload. """ @@ -447,7 +448,9 @@ class TestFileService: # Verify storage was called mock_external_service_dependencies["storage"].save.assert_called_once() - def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_name_too_long( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with name that exceeds length limit. """ @@ -472,7 +475,9 @@ class TestFileService: assert upload_file.name == "a" * 200 # Test get_file_preview method - def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_file_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful file preview generation. """ @@ -484,9 +489,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() result = FileService(engine).get_file_preview(file_id=upload_file.id) @@ -494,7 +498,7 @@ class TestFileService: mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() def test_get_file_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with non-existent file. @@ -506,7 +510,7 @@ class TestFileService: FileService(engine).get_file_preview(file_id=non_existent_id) def test_get_file_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with unsupported file type. @@ -519,15 +523,14 @@ class TestFileService: # Update file to have non-document extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_file_preview(file_id=upload_file.id) def test_get_file_preview_text_truncation( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with text that exceeds preview limit. @@ -540,9 +543,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock long text content long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT @@ -554,7 +556,9 @@ class TestFileService: assert result == "x" * 3000 # Test get_image_preview method - def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_image_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful image preview generation. """ @@ -566,9 +570,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -586,7 +589,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once() def test_get_image_preview_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with invalid signature. @@ -613,7 +616,7 @@ class TestFileService: ) def test_get_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-existent file. @@ -634,7 +637,7 @@ class TestFileService: ) def test_get_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-image file type. @@ -647,9 +650,8 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -665,7 +667,7 @@ class TestFileService: # Test get_file_generator_by_file_id method def test_get_file_generator_by_file_id_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful file generator retrieval. @@ -692,7 +694,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once() def test_get_file_generator_by_file_id_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with invalid signature. @@ -719,7 +721,7 @@ class TestFileService: ) def test_get_file_generator_by_file_id_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with non-existent file. @@ -741,7 +743,7 @@ class TestFileService: # Test get_public_image_preview method def test_get_public_image_preview_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful public image preview generation. @@ -754,9 +756,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id) @@ -765,7 +766,7 @@ class TestFileService: mock_external_service_dependencies["storage"].load.assert_called_once() def test_get_public_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-existent file. @@ -777,7 +778,7 @@ class TestFileService: FileService(engine).get_public_image_preview(file_id=non_existent_id) def test_get_public_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-image file type. @@ -790,15 +791,16 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_public_image_preview(file_id=upload_file.id) # Test edge cases and boundary conditions - def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_content( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty content. """ @@ -820,7 +822,7 @@ class TestFileService: assert upload_file.size == 0 def test_upload_file_special_characters_in_name( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with special characters in filename (but valid ones). @@ -843,7 +845,7 @@ class TestFileService: assert upload_file.name == filename def test_upload_file_different_case_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with different case extensions. @@ -865,7 +867,9 @@ class TestFileService: assert upload_file is not None assert upload_file.extension == "pdf" # Should be converted to lowercase - def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_empty_text( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with empty text. """ @@ -888,7 +892,9 @@ class TestFileService: assert upload_file is not None assert upload_file.size == 0 - def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_file_size_limits_edge_cases( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file size limits with edge case values. """ @@ -908,7 +914,9 @@ class TestFileService: result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size) assert result is False - def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_source_url( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with source URL that gets overridden by signed URL. """ @@ -946,7 +954,7 @@ class TestFileService: # Test file extension blacklist def test_upload_file_blocked_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension. @@ -969,7 +977,7 @@ class TestFileService: ) def test_upload_file_blocked_extension_case_insensitive( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension (case insensitive). @@ -992,7 +1000,9 @@ class TestFileService: user=account, ) - def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_not_in_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with extension not in blacklist. """ @@ -1016,7 +1026,9 @@ class TestFileService: assert upload_file.name == filename assert upload_file.extension == "pdf" - def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty blacklist (default behavior). """ @@ -1041,7 +1053,7 @@ class TestFileService: assert upload_file.extension == "sh" def test_upload_file_multiple_blocked_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with multiple blocked extensions. @@ -1066,7 +1078,7 @@ class TestFileService: ) def test_upload_file_no_extension_with_blacklist( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with no extension when blacklist is configured. diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index ece6de6cdf..19a684a58a 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.model import MessageFeedback from services.app_service import AppService @@ -69,7 +70,7 @@ class TestMessageService: # "current_user": mock_current_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -127,11 +128,10 @@ class TestMessageService: # mock_external_service_dependencies["current_user"].id = account_id # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -153,17 +153,16 @@ class TestMessageService: from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -192,11 +191,13 @@ class TestMessageService: from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by first ID. """ @@ -204,10 +205,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by first ID @@ -228,7 +229,9 @@ class TestMessageService: # Verify messages are in ascending order assert result.data[0].created_at <= result.data[1].created_at - def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by first ID when no user is provided. """ @@ -246,7 +249,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_no_conversation_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID when no conversation ID is provided. @@ -265,7 +268,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_invalid_first_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID with invalid first_id. @@ -274,8 +277,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid first_id with pytest.raises(FirstMessageNotExistsError): @@ -287,7 +290,9 @@ class TestMessageService: limit=10, ) - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID. """ @@ -295,10 +300,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by last ID @@ -319,7 +324,7 @@ class TestMessageService: assert result.data[0].created_at >= result.data[1].created_at def test_pagination_by_last_id_with_include_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with include_ids filter. @@ -328,10 +333,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination with include_ids @@ -347,7 +352,9 @@ class TestMessageService: for message in result.data: assert message.id in include_ids - def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by last ID when no user is provided. """ @@ -363,7 +370,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_last_id_invalid_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with invalid last_id. @@ -372,8 +379,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid last_id with pytest.raises(LastMessageNotExistsError): @@ -385,7 +392,7 @@ class TestMessageService: conversation_id=conversation.id, ) - def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of feedback. """ @@ -393,8 +400,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create feedback rating = "like" @@ -413,7 +420,7 @@ class TestMessageService: assert feedback.from_account_id == account.id assert feedback.from_end_user_id is None - def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test creating feedback when no user is provided. """ @@ -421,8 +428,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no user with pytest.raises(ValueError, match="user cannot be None"): @@ -430,7 +437,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) ) - def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating existing feedback. """ @@ -438,8 +447,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback initial_rating = "like" @@ -462,7 +471,9 @@ class TestMessageService: assert updated_feedback.rating != initial_rating assert updated_feedback.content != initial_content - def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_delete_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting existing feedback by setting rating to None. """ @@ -470,8 +481,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback feedback = MessageService.create_feedback( @@ -482,13 +493,14 @@ class TestMessageService: MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None) # Verify feedback was deleted - from extensions.ext_database import db - deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + deleted_feedback = ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + ) assert deleted_feedback is None def test_create_feedback_no_rating_when_not_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating feedback with no rating when feedback doesn't exist. @@ -497,8 +509,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no rating when no feedback exists with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"): @@ -506,7 +518,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=account, rating=None, content=None ) - def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_messages_feedbacks_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all message feedbacks. """ @@ -516,8 +530,8 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks feedbacks = [] for i in range(3): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) feedback = MessageService.create_feedback( app_model=app, @@ -539,7 +553,7 @@ class TestMessageService: assert result[i]["created_at"] >= result[i + 1]["created_at"] def test_get_all_messages_feedbacks_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of message feedbacks. @@ -549,8 +563,8 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks for i in range(5): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" @@ -569,7 +583,7 @@ class TestMessageService: page_2_ids = {feedback["id"] for feedback in result_page_2} assert len(page_1_ids.intersection(page_2_ids)) == 0 - def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of message. """ @@ -577,8 +591,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Get message retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id) @@ -590,7 +604,7 @@ class TestMessageService: assert retrieved_message.from_source == "console" assert retrieved_message.from_account_id == account.id - def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message that doesn't exist. """ @@ -601,7 +615,7 @@ class TestMessageService: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4()) - def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message with wrong user (different account). """ @@ -609,8 +623,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create another account from services.account_service import AccountService, TenantService @@ -628,7 +642,7 @@ class TestMessageService: MessageService.get_message(app_model=app, user=other_account, message_id=message.id) def test_get_suggested_questions_after_answer_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful generation of suggested questions after answer. @@ -637,8 +651,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the LLMGenerator to return specific questions mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"] @@ -665,7 +679,7 @@ class TestMessageService: mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() def test_get_suggested_questions_after_answer_no_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no user is provided. @@ -674,8 +688,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test getting suggested questions with no user from core.app.entities.app_invoke_entities import InvokeFrom @@ -686,7 +700,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when feature is disabled. @@ -695,8 +709,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the feature to be disabled mock_external_service_dependencies[ @@ -712,7 +726,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_no_workflow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no workflow exists. @@ -721,8 +735,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock no workflow mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None @@ -738,7 +752,7 @@ class TestMessageService: assert result == [] def test_get_suggested_questions_after_answer_debugger_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions in debugger mode. @@ -747,8 +761,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock questions mock_questions = ["Debug question 1", "Debug question 2"] diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 5b6db64c09..6fe40c0744 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -6,9 +6,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import ( @@ -40,25 +40,25 @@ class TestMessagesCleanServiceIntegration: PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before and after each test to ensure isolation.""" yield # Clear all test data in correct order (respecting foreign key constraints) - db.session.query(DatasetRetrieverResource).delete() - db.session.query(AppAnnotationHitHistory).delete() - db.session.query(SavedMessage).delete() - db.session.query(MessageFile).delete() - db.session.query(MessageAgentThought).delete() - db.session.query(MessageChain).delete() - db.session.query(MessageAnnotation).delete() - db.session.query(MessageFeedback).delete() - db.session.query(Message).delete() - db.session.query(Conversation).delete() - db.session.query(App).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DatasetRetrieverResource).delete() + db_session_with_containers.query(AppAnnotationHitHistory).delete() + db_session_with_containers.query(SavedMessage).delete() + db_session_with_containers.query(MessageFile).delete() + db_session_with_containers.query(MessageAgentThought).delete() + db_session_with_containers.query(MessageChain).delete() + db_session_with_containers.query(MessageAnnotation).delete() + db_session_with_containers.query(MessageFeedback).delete() + db_session_with_containers.query(Message).delete() + db_session_with_containers.query(Conversation).delete() + db_session_with_containers.query(App).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() @pytest.fixture(autouse=True) def cleanup_redis(self): @@ -100,7 +100,7 @@ class TestMessagesCleanServiceIntegration: with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): yield - def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX): """Helper to create account and tenant.""" fake = Faker() @@ -110,28 +110,28 @@ class TestMessagesCleanServiceIntegration: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() tenant = Tenant( name=fake.company(), plan=str(plan), status="normal", ) - db.session.add(tenant) - db.session.flush() + db_session_with_containers.add(tenant) + db_session_with_containers.flush() tenant_account_join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER, ) - db.session.add(tenant_account_join) - db.session.commit() + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() return account, tenant - def _create_app(self, tenant, account): + def _create_app(self, db_session_with_containers: Session, tenant, account): """Helper to create an app.""" fake = Faker() @@ -149,12 +149,12 @@ class TestMessagesCleanServiceIntegration: created_by=account.id, updated_by=account.id, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_conversation(self, app): + def _create_conversation(self, db_session_with_containers: Session, app): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, @@ -168,12 +168,14 @@ class TestMessagesCleanServiceIntegration: from_source="api", from_end_user_id=str(uuid.uuid4()), ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def _create_message(self, app, conversation, created_at=None, with_relations=True): + def _create_message( + self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True + ): """Helper to create a message with optional related records.""" if created_at is None: created_at = datetime.datetime.now() @@ -197,16 +199,16 @@ class TestMessagesCleanServiceIntegration: from_account_id=conversation.from_end_user_id, created_at=created_at, ) - db.session.add(message) - db.session.flush() + db_session_with_containers.add(message) + db_session_with_containers.flush() if with_relations: - self._create_message_relations(message) + self._create_message_relations(db_session_with_containers, message) - db.session.commit() + db_session_with_containers.commit() return message - def _create_message_relations(self, message): + def _create_message_relations(self, db_session_with_containers: Session, message): """Helper to create all message-related records.""" # MessageFeedback feedback = MessageFeedback( @@ -217,7 +219,7 @@ class TestMessagesCleanServiceIntegration: from_source="api", from_end_user_id=str(uuid.uuid4()), ) - db.session.add(feedback) + db_session_with_containers.add(feedback) # MessageAnnotation annotation = MessageAnnotation( @@ -228,7 +230,7 @@ class TestMessagesCleanServiceIntegration: content="Test annotation", account_id=message.from_account_id, ) - db.session.add(annotation) + db_session_with_containers.add(annotation) # MessageChain chain = MessageChain( @@ -237,8 +239,8 @@ class TestMessagesCleanServiceIntegration: input=json.dumps({"test": "input"}), output=json.dumps({"test": "output"}), ) - db.session.add(chain) - db.session.flush() + db_session_with_containers.add(chain) + db_session_with_containers.flush() # MessageFile file = MessageFile( @@ -250,7 +252,7 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(file) + db_session_with_containers.add(file) # SavedMessage saved = SavedMessage( @@ -259,9 +261,9 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(saved) + db_session_with_containers.add(saved) - db.session.flush() + db_session_with_containers.flush() # AppAnnotationHitHistory hit = AppAnnotationHitHistory( @@ -275,7 +277,7 @@ class TestMessagesCleanServiceIntegration: annotation_question="Test annotation question", annotation_content="Test annotation content", ) - db.session.add(hit) + db_session_with_containers.add(hit) # DatasetRetrieverResource resource = DatasetRetrieverResource( @@ -296,25 +298,29 @@ class TestMessagesCleanServiceIntegration: retriever_from="dataset", created_by=message.from_account_id, ) - db.session.add(resource) + db_session_with_containers.add(resource) def test_billing_disabled_deletes_all_messages_in_time_range( - self, db_session_with_containers, mock_billing_disabled + self, db_session_with_containers: Session, mock_billing_disabled ): """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" # Arrange - Create tenant with messages (plan doesn't matter for billing disabled) - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: in-range (should be deleted) and out-of-range (should be kept) in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) - in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True + ) in_range_msg_id = in_range_msg.id - out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True + ) out_of_range_msg_id = out_of_range_msg.id # Act - create_message_clean_policy should return BillingDisabledPolicy @@ -336,17 +342,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # In-range message deleted - assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0 # Out-of-range message kept - assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 # Related records of in-range message deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 - assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == in_range_msg_id) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id == in_range_msg_id) + .count() + == 0 + ) # Related records of out-of-range message kept - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == out_of_range_msg_id) + .count() + == 1 + ) - def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_no_messages_returns_empty_stats( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning when there are no messages to delete (B1).""" # Arrange end_before = datetime.datetime.now() - datetime.timedelta(days=30) @@ -371,36 +394,42 @@ class TestMessagesCleanServiceIntegration: assert stats["filtered_messages"] == 0 assert stats["total_deleted"] == 0 - def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_mixed_sandbox_and_paid_tenants( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning with mixed sandbox and paid tenants (B2).""" # Arrange - Create sandbox tenants with expired messages sandbox_tenants = [] sandbox_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) sandbox_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 3 expired messages per sandbox tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) sandbox_message_ids.append(msg.id) # Create paid tenants with expired messages (should NOT be deleted) paid_tenants = [] paid_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) paid_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 2 expired messages per paid tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(2): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) paid_message_ids.append(msg.id) # Mock billing service - return plan and expiration_date @@ -442,29 +471,39 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 6 # Only sandbox messages should be deleted - assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 # Paid messages should remain - assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 # Related records of sandbox messages should be deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 assert ( - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id.in_(sandbox_message_ids)) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id.in_(sandbox_message_ids)) + .count() == 0 ) - def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_cursor_pagination_multiple_batches( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cursor pagination works correctly across multiple batches (B3).""" # Arrange - Create sandbox tenant with messages that will span multiple batches - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 10 expired messages with different timestamps base_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(10): msg = self._create_message( + db_session_with_containers, app, conv, created_at=base_date + datetime.timedelta(hours=i), @@ -498,20 +537,22 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 10 # All messages should be deleted - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0 - def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_dry_run_does_not_delete(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test dry_run mode does not delete messages (B4).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create expired messages expired_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) message_ids.append(msg.id) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -540,21 +581,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # But NOT deleted # All messages should still exist - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3 # Related records should also still exist - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + assert ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() + == 3 + ) - def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_partial_plan_data_safe_default( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns partial data, unknown tenants are preserved (B5).""" # Arrange - Create 3 tenants tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) tenants_data.append( { @@ -600,28 +646,30 @@ class TestMessagesCleanServiceIntegration: # Check which messages were deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 ) # Sandbox tenant's message deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 ) # Professional tenant's message preserved assert ( - db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 ) # Unknown tenant's message preserved (safe default) - def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_empty_plan_data_skips_deletion( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns empty data, skip deletion entirely (B6).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) msg_id = msg.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service to return empty data (simulating failure/no data scenario) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -644,17 +692,20 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Message should still exist (safe default - don't delete if plan is unknown) - assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1 - def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_time_range_boundary_behavior( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: before range, in range, after range msg_before = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from @@ -663,6 +714,7 @@ class TestMessagesCleanServiceIntegration: msg_before_id = msg_before.id msg_at_start = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) @@ -671,6 +723,7 @@ class TestMessagesCleanServiceIntegration: msg_at_start_id = msg_at_start.id msg_in_range = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range @@ -679,6 +732,7 @@ class TestMessagesCleanServiceIntegration: msg_in_range_id = msg_in_range.id msg_at_end = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) @@ -687,6 +741,7 @@ class TestMessagesCleanServiceIntegration: msg_at_end_id = msg_at_end.id msg_after = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before @@ -694,7 +749,7 @@ class TestMessagesCleanServiceIntegration: ) msg_after_id = msg_after.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -722,17 +777,17 @@ class TestMessagesCleanServiceIntegration: # Verify specific messages using stored IDs # Before range, kept - assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1 # At start (inclusive), deleted - assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0 # In range, deleted - assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0 # At end (exclusive), kept - assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1 # After range, kept - assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1 - def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test cleaning with different graceful period scenarios (B8).""" # Arrange - Create 5 different tenants with different plan and expiration scenarios now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) @@ -740,50 +795,60 @@ class TestMessagesCleanServiceIntegration: # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) # Should NOT be deleted - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) msg1_id = msg1.id expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) # Should be deleted - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) msg2_id = msg2.id expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) # Should be deleted - account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app3 = self._create_app(tenant3, account3) - conv3 = self._create_conversation(app3) - msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app3 = self._create_app(db_session_with_containers, tenant3, account3) + conv3 = self._create_conversation(db_session_with_containers, app3) + msg3 = self._create_message( + db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False + ) msg3_id = msg3.id # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) # Should NOT be deleted - account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) - app4 = self._create_app(tenant4, account4) - conv4 = self._create_conversation(app4) - msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(db_session_with_containers, tenant4, account4) + conv4 = self._create_conversation(db_session_with_containers, app4) + msg4 = self._create_message( + db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False + ) msg4_id = msg4.id future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) # Should NOT be deleted (boundary is exclusive: > graceful_period) - account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app5 = self._create_app(tenant5, account5) - conv5 = self._create_conversation(app5) - msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app5 = self._create_app(db_session_with_containers, tenant5, account5) + conv5 = self._create_conversation(db_session_with_containers, app5) + msg5 = self._create_message( + db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False + ) msg5_id = msg5.id expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary - db.session.commit() + db_session_with_containers.commit() # Mock billing service with all scenarios plan_map = { @@ -832,23 +897,31 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 2 # Verify each scenario using saved IDs - assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept - assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted - assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted - assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept - assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0 + ) # Beyond grace, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg3_id).count() == 0 + ) # No subscription, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1 + ) # Professional plan, kept + assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept - def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test that whitelisted tenants' messages are not deleted (B9).""" # Arrange - Create 3 sandbox tenants with expired messages tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date, with_relations=False + ) tenants_data.append( { @@ -897,27 +970,33 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # Verify tenant0's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 # Verify tenant1's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 # Verify tenant2's message was deleted (not whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 - def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_from_days_cleans_old_messages( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test from_days correctly cleans messages older than N days (B11).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create old messages (should be deleted - older than 30 days) old_date = datetime.datetime.now() - datetime.timedelta(days=45) old_msg_ids = [] for i in range(3): msg = self._create_message( - app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=old_date - datetime.timedelta(hours=i), + with_relations=False, ) old_msg_ids.append(msg.id) @@ -926,11 +1005,15 @@ class TestMessagesCleanServiceIntegration: recent_msg_ids = [] for i in range(2): msg = self._create_message( - app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=recent_date - datetime.timedelta(hours=i), + with_relations=False, ) recent_msg_ids.append(msg.id) - db.session.commit() + db_session_with_containers.commit() with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: mock_billing.return_value = { @@ -955,30 +1038,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Old messages deleted - assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 # Recent messages kept - assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 def test_whitelist_precedence_over_grace_period( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that whitelist takes precedence over grace period logic.""" # Arrange - Create 2 sandbox tenants now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) # Tenant1: whitelisted, expired beyond grace period - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace # Tenant2: not whitelisted, within grace period - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace # Mock billing service @@ -1019,22 +1106,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Verify both messages still exist - assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted - assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1 + ) # Within grace period def test_empty_whitelist_deletes_eligible_messages( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" # Arrange - Create sandbox tenant with expired messages - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) msg_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) msg_ids.append(msg.id) # Mock billing service @@ -1068,4 +1159,4 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Verify all messages were deleted - assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e04725627b..694dc1c1b9 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -32,7 +33,7 @@ class TestMetadataService: "document_service": mock_document_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -53,18 +54,16 @@ class TestMetadataService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -73,15 +72,17 @@ class TestMetadataService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + def _create_test_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant + ): """ Helper method to create a test dataset for testing. @@ -105,14 +106,14 @@ class TestMetadataService: built_in_field_enabled=False, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + def _create_test_document( + self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account + ): """ Helper method to create a test document for testing. @@ -141,14 +142,12 @@ class TestMetadataService: doc_language="en", ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata creation with valid parameters. """ @@ -178,13 +177,14 @@ class TestMetadataService: assert result.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None - def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name exceeds 255 characters. """ @@ -207,7 +207,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name already exists in the same dataset. """ @@ -235,7 +237,7 @@ class TestMetadataService: MetadataService.create_metadata(dataset.id, second_metadata_args) def test_create_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata creation fails when name conflicts with built-in field names. @@ -260,7 +262,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful metadata name update with valid parameters. """ @@ -291,12 +295,13 @@ class TestMetadataService: assert result.updated_at is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == new_name - def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name exceeds 255 characters. """ @@ -323,7 +328,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) - def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name already exists in the same dataset. """ @@ -351,7 +358,7 @@ class TestMetadataService: MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") def test_update_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata name update fails when new name conflicts with built-in field names. @@ -378,7 +385,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) - def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when metadata ID does not exist. """ @@ -406,7 +415,7 @@ class TestMetadataService: # Assert: Verify the method returns None when metadata is not found assert result is None - def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata deletion with valid parameters. """ @@ -434,12 +443,11 @@ class TestMetadataService: assert result.id == metadata.id # Verify metadata was deleted from database - from extensions.ext_database import db - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None - def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test metadata deletion fails when metadata ID does not exist. """ @@ -467,7 +475,7 @@ class TestMetadataService: assert result is None def test_delete_metadata_with_document_bindings( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata deletion successfully removes document metadata bindings. @@ -500,15 +508,13 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Set document metadata document.doc_metadata = {"test_metadata": "test_value"} - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.delete_metadata(dataset.id, metadata.id) @@ -517,13 +523,13 @@ class TestMetadataService: assert result is not None # Verify metadata was deleted from database - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None # Note: The service attempts to update document metadata but may not succeed # due to mock configuration. The main functionality (metadata deletion) is verified. - def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of built-in metadata fields. """ @@ -548,7 +554,9 @@ class TestMetadataService: assert "string" in field_types assert "time" in field_types - def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of built-in fields for a dataset. """ @@ -579,16 +587,15 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (enabling built-in fields) is verified def test_enable_built_in_field_already_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields when they are already enabled. @@ -607,10 +614,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -619,11 +625,11 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the method returns early without changes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True def test_enable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields for a dataset with no documents. @@ -647,12 +653,13 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True - def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of built-in fields for a dataset. """ @@ -673,10 +680,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Set document metadata with built-in fields document.doc_metadata = { @@ -686,8 +692,8 @@ class TestMetadataService: BuiltInField.last_update_date: 1234567890.0, BuiltInField.source: "test_source", } - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ @@ -698,14 +704,14 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (disabling built-in fields) is verified def test_disable_built_in_field_already_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields when they are already disabled. @@ -732,13 +738,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the method returns early without changes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False def test_disable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields for a dataset with no documents. @@ -757,10 +762,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id to return empty list mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -769,10 +773,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False - def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_documents_metadata_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of documents metadata. """ @@ -815,24 +821,25 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify document metadata was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" # Verify metadata binding was created binding = ( - db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + db_session_with_containers.query(DatasetMetadataBinding) + .filter_by(metadata_id=metadata.id, document_id=document.id) + .first() ) assert binding is not None assert binding.tenant_id == tenant.id assert binding.dataset_id == dataset.id def test_update_documents_metadata_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when built-in fields are enabled. @@ -850,10 +857,9 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id @@ -884,7 +890,7 @@ class TestMetadataService: # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" @@ -893,7 +899,7 @@ class TestMetadataService: # The main functionality (custom metadata update) is verified def test_update_documents_metadata_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when document is not found. @@ -936,7 +942,7 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) def test_knowledge_base_metadata_lock_check_dataset_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for dataset operations. @@ -959,7 +965,7 @@ class TestMetadataService: assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" def test_knowledge_base_metadata_lock_check_document_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for document operations. @@ -982,7 +988,7 @@ class TestMetadataService: assert call_args[0][0] == f"document_metadata_lock_{document_id}" def test_knowledge_base_metadata_lock_check_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when lock already exists. @@ -999,7 +1005,7 @@ class TestMetadataService: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) def test_knowledge_base_metadata_lock_check_document_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when document lock already exists. @@ -1013,7 +1019,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): MetadataService.knowledge_base_metadata_lock_check(None, document_id) - def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of dataset metadata information. """ @@ -1046,10 +1054,8 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.get_dataset_metadatas(dataset) @@ -1071,7 +1077,7 @@ class TestMetadataService: assert result["built_in_field_enabled"] is False def test_get_dataset_metadatas_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of dataset metadata when built-in fields are enabled. @@ -1086,10 +1092,9 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id @@ -1114,7 +1119,9 @@ class TestMetadataService: # Verify built-in field status assert result["built_in_field_enabled"] is True - def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_no_metadata( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of dataset metadata when no metadata exists. """ diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 7c8472e819..989df42499 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -67,7 +68,7 @@ class TestModelLoadBalancingService: "credential_schema": mock_credential_schema, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -88,18 +89,16 @@ class TestModelLoadBalancingService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -108,8 +107,8 @@ class TestModelLoadBalancingService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -117,7 +116,7 @@ class TestModelLoadBalancingService: return account, tenant def _create_test_provider_and_setting( - self, db_session_with_containers, tenant_id, mock_external_service_dependencies + self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies ): """ Helper method to create a test provider and provider model setting. @@ -132,8 +131,6 @@ class TestModelLoadBalancingService: """ fake = Faker() - from extensions.ext_database import db - # Create provider provider = Provider( tenant_id=tenant_id, @@ -141,8 +138,8 @@ class TestModelLoadBalancingService: provider_type="custom", is_valid=True, ) - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Create provider model setting provider_model_setting = ProviderModelSetting( @@ -153,12 +150,14 @@ class TestModelLoadBalancingService: enabled=True, load_balancing_enabled=False, ) - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider, provider_model_setting - def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing enablement. @@ -193,14 +192,15 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None - def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing disablement. @@ -235,15 +235,14 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None def test_enable_model_load_balancing_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist. @@ -275,11 +274,12 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() - def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_load_balancing_configs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of load balancing configurations. @@ -298,7 +298,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -309,11 +308,11 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Verify the config was created - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Setup mocks for get_load_balancing_configs method @@ -358,11 +357,11 @@ class TestModelLoadBalancingService: assert configs[0]["ttl"] == 0 # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None def test_get_load_balancing_configs_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist in get_load_balancing_configs. @@ -394,12 +393,11 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() def test_get_load_balancing_configs_with_inherit_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test load balancing configs retrieval with inherit configuration. @@ -419,7 +417,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -430,8 +427,8 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Setup mocks for inherit config scenario mock_provider_config = mock_external_service_dependencies["provider_config"] @@ -467,11 +464,11 @@ class TestModelLoadBalancingService: assert configs[1]["name"] == "config1" # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = db.session.scalars( + inherit_configs = db_session_with_containers.scalars( select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 7a4662055c..6afc5aa43c 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType @@ -29,7 +30,7 @@ class TestModelProviderService: "model_provider_factory": mock_model_provider_factory, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,18 +51,16 @@ class TestModelProviderService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestModelProviderService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -80,7 +79,7 @@ class TestModelProviderService: def _create_test_provider( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str = "openai", @@ -109,16 +108,14 @@ class TestModelProviderService: quota_used=0, ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider def _create_test_provider_model( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -149,16 +146,14 @@ class TestModelProviderService: is_valid=True, ) - from extensions.ext_database import db - - db.session.add(provider_model) - db.session.commit() + db_session_with_containers.add(provider_model) + db_session_with_containers.commit() return provider_model def _create_test_provider_model_setting( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -190,14 +185,12 @@ class TestModelProviderService: load_balancing_enabled=False, ) - from extensions.ext_database import db - - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider_model_setting - def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful provider list retrieval. @@ -275,7 +268,7 @@ class TestModelProviderService: mock_provider_config.is_custom_configuration_available.assert_called_once() def test_get_provider_list_with_model_type_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test provider list retrieval with model type filtering. @@ -374,7 +367,9 @@ class TestModelProviderService: assert result[0].provider == "cohere" assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types - def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by provider. @@ -485,7 +480,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_configurations.get_models.assert_called_once_with(provider="openai") - def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of provider credentials. @@ -543,7 +540,7 @@ class TestModelProviderService: mock_method.assert_called_once_with(tenant.id, "openai") def test_provider_credentials_validate_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful validation of provider credentials. @@ -585,7 +582,7 @@ class TestModelProviderService: mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) def test_provider_credentials_validate_invalid_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation failure for non-existent provider. @@ -617,7 +614,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) def test_get_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of default model for a specific model type. @@ -673,7 +670,7 @@ class TestModelProviderService: mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM) def test_update_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of default model for a specific model type. @@ -706,7 +703,9 @@ class TestModelProviderService: tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4" ) - def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_provider_icon_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model provider icon. @@ -743,7 +742,9 @@ class TestModelProviderService: # Verify mock interactions mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") - def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_preferred_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful switching of preferred provider type. @@ -779,7 +780,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.switch_preferred_provider_type.assert_called_once() - def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful enabling of a model. @@ -815,7 +816,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4") - def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model credentials. @@ -872,7 +875,9 @@ class TestModelProviderService: # Verify the method was called with correct parameters mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) - def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_model_credentials_validate_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful validation of model credentials. @@ -914,7 +919,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials ) - def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful saving of model credentials. @@ -955,7 +962,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" ) - def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful removal of model credentials. @@ -993,7 +1002,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" ) - def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_model_type_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by model type. @@ -1070,7 +1081,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) - def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_parameter_rules_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model parameter rules. @@ -1137,7 +1150,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when no credentials are available. @@ -1181,7 +1194,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when provider does not exist. diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 9e6b9837ae..e3ec1d1df3 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.model import EndUser, Message from models.web import SavedMessage @@ -38,7 +39,7 @@ class TestSavedMessageService: "message_service": mock_message_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -85,7 +86,7 @@ class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -108,14 +109,12 @@ class TestSavedMessageService: is_anonymous=False, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_message(self, db_session_with_containers, app, user): + def _create_test_message(self, db_session_with_containers: Session, app, user): """ Helper method to create a test message for testing. @@ -143,10 +142,8 @@ class TestSavedMessageService: mode="chat", ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -168,13 +165,13 @@ class TestSavedMessageService: status="success", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_pagination_by_last_id_success_with_account_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with account user. @@ -207,10 +204,8 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -240,15 +235,15 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "account" assert saved_message2.created_by_role == "account" def test_pagination_by_last_id_success_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with end user. @@ -282,10 +277,8 @@ class TestSavedMessageService: created_by=end_user.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -317,14 +310,16 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "end_user" assert saved_message2.created_by_role == "end_user" - def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_success_with_new_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful save of a new message. @@ -347,10 +342,9 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was created in database - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -373,10 +367,12 @@ class TestSavedMessageService: ) # Verify database state - db.session.refresh(saved_message) + db_session_with_containers.refresh(saved_message) assert saved_message.id is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -396,12 +392,11 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) # Verify no database operations were performed - from extensions.ext_database import db - saved_messages = db.session.query(SavedMessage).all() + saved_messages = db_session_with_containers.query(SavedMessage).all() assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -422,10 +417,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -435,7 +429,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -457,14 +453,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -481,7 +475,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -494,11 +488,13 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -522,7 +518,7 @@ class TestSavedMessageService: # Instead, we verify that the error was properly raised pass - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -543,10 +539,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -556,7 +551,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -578,14 +575,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -602,7 +597,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -615,6 +610,6 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index e8c7f17e0b..597ba6b75b 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -4,6 +4,7 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestTagService: "current_user": mock_current_user, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,18 +51,16 @@ class TestTagService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestTagService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -82,7 +81,7 @@ class TestTagService: return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test dataset for testing. @@ -107,14 +106,12 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test app for testing. @@ -141,15 +138,13 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app def _create_test_tags( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3 ): """ Helper method to create test tags for testing. @@ -176,16 +171,14 @@ class TestTagService: ) tags.append(tag) - from extensions.ext_database import db - for tag in tags: - db.session.add(tag) - db.session.commit() + db_session_with_containers.add(tag) + db_session_with_containers.commit() return tags def _create_test_tag_bindings( - self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id ): """ Helper method to create test tag bindings for testing. @@ -211,15 +204,13 @@ class TestTagService: ) tag_bindings.append(tag_binding) - from extensions.ext_database import db - for tag_binding in tag_bindings: - db.session.add(tag_binding) - db.session.commit() + db_session_with_containers.add(tag_binding) + db_session_with_containers.commit() return tag_bindings - def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags with binding count. @@ -270,7 +261,9 @@ class TestTagService: # The ordering is handled by the database, we just verify the results are returned assert len(result) == 3 - def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_with_keyword_filter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval with keyword filtering. @@ -291,12 +284,11 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_development" tags[1].name = "machine_learning" tags[2].name = "web_development" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword filter result = TagService.get_tags("app", tenant.id, keyword="development") @@ -314,7 +306,7 @@ class TestTagService: assert len(result_no_match) == 0 def test_get_tags_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test tag retrieval with special characters in keyword to verify SQL injection prevention. @@ -330,8 +322,6 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - from extensions.ext_database import db - # Create tags with special characters in names tag_with_percent = Tag( name="50% discount", @@ -340,7 +330,7 @@ class TestTagService: created_by=account.id, ) tag_with_percent.id = str(uuid.uuid4()) - db.session.add(tag_with_percent) + db_session_with_containers.add(tag_with_percent) tag_with_underscore = Tag( name="test_data_tag", @@ -349,7 +339,7 @@ class TestTagService: created_by=account.id, ) tag_with_underscore.id = str(uuid.uuid4()) - db.session.add(tag_with_underscore) + db_session_with_containers.add(tag_with_underscore) tag_with_backslash = Tag( name="path\\to\\tag", @@ -358,7 +348,7 @@ class TestTagService: created_by=account.id, ) tag_with_backslash.id = str(uuid.uuid4()) - db.session.add(tag_with_backslash) + db_session_with_containers.add(tag_with_backslash) # Create tag that should NOT match tag_no_match = Tag( @@ -368,9 +358,9 @@ class TestTagService: created_by=account.id, ) tag_no_match.id = str(uuid.uuid4()) - db.session.add(tag_no_match) + db_session_with_containers.add(tag_no_match) - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Test 1 - Search with % character result = TagService.get_tags("app", tenant.id, keyword="50%") @@ -392,7 +382,7 @@ class TestTagService: assert len(result) == 1 assert all("50%" in item.name for item in result) - def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. @@ -414,7 +404,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_target_ids_by_tag_ids_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of target IDs by tag IDs. @@ -469,7 +461,7 @@ class TestTagService: assert second_dataset_count == 1 def test_get_target_ids_by_tag_ids_empty_tag_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval with empty tag IDs list. @@ -493,7 +485,7 @@ class TestTagService: assert isinstance(result, list) def test_get_target_ids_by_tag_ids_no_matching_tags( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval when no tags match the criteria. @@ -521,7 +513,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags by tag name. @@ -542,11 +534,10 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_tag" tags[1].name = "ml_tag" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") @@ -558,7 +549,9 @@ class TestTagService: assert result[0].type == "app" assert result[0].tenant_id == tenant.id - def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_no_matches( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name when no matches exist. @@ -580,7 +573,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_empty_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name with empty parameters. @@ -605,7 +600,9 @@ class TestTagService: assert result_empty_name is not None assert len(result_empty_name) == 0 - def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tags by target ID. @@ -644,7 +641,9 @@ class TestTagService: assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] - def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_no_bindings( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by target ID when no tags are bound. @@ -669,7 +668,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag creation. @@ -698,17 +697,18 @@ class TestTagService: assert result.id is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None # Verify tag was actually saved to database - saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() + saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first() assert saved_tag is not None assert saved_tag.name == "test_tag_name" - def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag creation with duplicate name. @@ -731,7 +731,7 @@ class TestTagService: TagService.save_tags(tag_args) assert "Tag name already exists" in str(exc_info.value) - def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag update. @@ -763,17 +763,16 @@ class TestTagService: assert result.id == tag.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == "updated_name" # Verify tag was actually updated in database - updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() + updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert updated_tag is not None assert updated_tag.name == "updated_name" - def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag update for non-existent tag. @@ -799,7 +798,9 @@ class TestTagService: TagService.update_tags(update_args, non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag update with duplicate name. @@ -828,7 +829,9 @@ class TestTagService: TagService.update_tags(update_args, tag2.id) assert "Tag name already exists" in str(exc_info.value) - def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_binding_count_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tag binding count. @@ -863,7 +866,7 @@ class TestTagService: assert result_tag_without_bindings == 0 def test_get_tag_binding_count_non_existent_tag( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test binding count retrieval for non-existent tag. @@ -889,7 +892,7 @@ class TestTagService: # Assert: Verify the expected outcomes assert result == 0 - def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag deletion. @@ -916,12 +919,11 @@ class TestTagService: ) # Verify tag and binding exist before deletion - from extensions.ext_database import db - tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_before is not None - binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_before is not None # Act: Execute the method under test @@ -929,14 +931,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag was deleted - tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_after is None # Verify tag binding was deleted - binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_after is None - def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag deletion for non-existent tag. @@ -960,7 +962,7 @@ class TestTagService: TagService.delete_tag(non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding creation. @@ -988,12 +990,11 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify tag bindings were created for tag in tags: binding = ( - db.session.query(TagBinding) + db_session_with_containers.query(TagBinding) .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) .first() ) @@ -1001,7 +1002,9 @@ class TestTagService: assert binding.tenant_id == tenant.id assert binding.created_by == account.id - def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_duplicate_handling( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with duplicate bindings. @@ -1032,15 +1035,16 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify only one binding exists - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 1 - def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_invalid_target_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with invalid target type. @@ -1071,7 +1075,7 @@ class TestTagService: TagService.save_tag_binding(binding_args) assert "Invalid binding type" in str(exc_info.value) - def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding deletion. @@ -1098,10 +1102,11 @@ class TestTagService: ) # Verify binding exists before deletion - from extensions.ext_database import db binding_before = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_before is not None @@ -1112,12 +1117,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag binding was deleted binding_after = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_after is None def test_delete_tag_binding_non_existent_binding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tag binding deletion for non-existent binding. @@ -1145,15 +1152,14 @@ class TestTagService: # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged - from extensions.ext_database import db - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful target existence check for knowledge type. @@ -1179,7 +1185,7 @@ class TestTagService: # No exception should be raised for existing dataset def test_check_target_exists_knowledge_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target existence check for non-existent knowledge dataset. @@ -1204,7 +1210,9 @@ class TestTagService: TagService.check_target_exists("knowledge", non_existent_dataset_id) assert "Dataset not found" in str(exc_info.value) - def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful target existence check for app type. @@ -1228,7 +1236,9 @@ class TestTagService: # Assert: Verify the expected outcomes # No exception should be raised for existing app - def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for non-existent app. @@ -1252,7 +1262,9 @@ class TestTagService: TagService.check_target_exists("app", non_existent_app_id) assert "App not found" in str(exc_info.value) - def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_invalid_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for invalid type. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 5315960d73..912aa3dd2f 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity -from extensions.ext_database import db from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService @@ -47,7 +47,7 @@ class TestTriggerProviderService: "account_feature_service": mock_account_feature_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -84,7 +84,7 @@ class TestTriggerProviderService: def _create_test_subscription( self, - db_session_with_containers, + db_session_with_containers: Session, tenant_id, user_id, provider_id, @@ -135,14 +135,14 @@ class TestTriggerProviderService: expires_at=-1, ) - db.session.add(subscription) - db.session.commit() - db.session.refresh(subscription) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + db_session_with_containers.refresh(subscription) return subscription def test_rebuild_trigger_subscription_success_with_merged_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful rebuild with credential merging (HIDDEN_VALUE handling). @@ -217,7 +217,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "new-secret-value" # New value # Verify database state was updated - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == "updated_name" assert subscription.parameters == {"param1": "updated_value"} @@ -244,7 +244,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_with_all_new_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are new (no HIDDEN_VALUE). @@ -304,7 +304,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "completely-new-secret" def test_rebuild_trigger_subscription_with_all_hidden_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). @@ -363,7 +363,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. @@ -422,7 +422,7 @@ class TestTriggerProviderService: assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE def test_rebuild_trigger_subscription_rollback_on_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that transaction is rolled back on error. @@ -470,12 +470,12 @@ class TestTriggerProviderService: ) # Verify subscription state was not changed (rolled back) - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == original_name assert subscription.parameters == original_parameters def test_rebuild_trigger_subscription_subscription_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error when subscription is not found. @@ -501,7 +501,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_name_uniqueness_check( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that name uniqueness is checked when updating name. diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index bbbf48ede9..f1e8c152f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models import Account @@ -45,7 +46,7 @@ class TestWebConversationService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -90,7 +91,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -111,14 +112,12 @@ class TestWebConversationService: tenant_id=app.tenant_id, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_conversation(self, db_session_with_containers, app, user, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake): """ Helper method to create a test conversation for testing. @@ -152,14 +151,14 @@ class TestWebConversationService: is_deleted=False, ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID with basic parameters. """ @@ -194,7 +193,7 @@ class TestWebConversationService: assert result.data[1].updated_at >= result.data[2].updated_at def test_pagination_by_last_id_with_pinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with pinned conversation filter. @@ -222,11 +221,9 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation1) - db.session.add(pinned_conversation2) - db.session.commit() + db_session_with_containers.add(pinned_conversation1) + db_session_with_containers.add(pinned_conversation2) + db_session_with_containers.commit() # Test pagination with pinned filter result = WebConversationService.pagination_by_last_id( @@ -251,7 +248,7 @@ class TestWebConversationService: assert set(returned_ids) == set(expected_ids) def test_pagination_by_last_id_with_unpinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with unpinned conversation filter. @@ -273,10 +270,8 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation) - db.session.commit() + db_session_with_containers.add(pinned_conversation) + db_session_with_containers.commit() # Test pagination with unpinned filter result = WebConversationService.pagination_by_last_id( @@ -303,7 +298,7 @@ class TestWebConversationService: expected_unpinned_ids = [conv.id for conv in conversations[1:]] assert set(returned_ids) == set(expected_unpinned_ids) - def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful pinning of a conversation. """ @@ -317,10 +312,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -336,7 +330,9 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "account" assert pinned_conversation.created_by == account.id - def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_already_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation that is already pinned (should not create duplicate). """ @@ -353,9 +349,8 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify only one pinned conversation record exists - from extensions.ext_database import db - pinned_conversations = db.session.scalars( + pinned_conversations = db_session_with_containers.scalars( select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -366,7 +361,9 @@ class TestWebConversationService: assert len(pinned_conversations) == 1 - def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation with an end user. """ @@ -383,10 +380,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, end_user) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -402,7 +398,7 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "end_user" assert pinned_conversation.created_by == end_user.id - def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful unpinning of a conversation. """ @@ -416,10 +412,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -436,7 +431,7 @@ class TestWebConversationService: # Verify it was unpinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -448,7 +443,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_not_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test unpinning a conversation that is not pinned (should not cause error). """ @@ -462,10 +459,9 @@ class TestWebConversationService: WebConversationService.unpin(app, conversation.id, account) # Verify no pinned conversation record exists - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -478,7 +474,7 @@ class TestWebConversationService: assert pinned_conversation is None def test_pagination_by_last_id_user_required_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that pagination_by_last_id raises ValueError when user is None. @@ -499,7 +495,7 @@ class TestWebConversationService: sort_by="-updated_at", ) - def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that pin method returns early when user is None. """ @@ -513,10 +509,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, None) # Verify no pinned conversation was created - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -526,7 +521,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_user_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that unpin method returns early when user is None. """ @@ -540,10 +537,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -560,7 +556,7 @@ class TestWebConversationService: # Verify the conversation is still pinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index d1c566e477..9a1595d266 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized from libs.password import hash_password @@ -45,7 +46,7 @@ class TestWebAppAuthService: "enterprise_service": mock_enterprise_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -68,18 +69,16 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,15 +87,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_with_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test account with password for testing. @@ -131,18 +132,16 @@ class TestWebAppAuthService: account.password = base64.b64encode(password_hash).decode() account.password_salt = base64.b64encode(salt).decode() - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -151,15 +150,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant, password - def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant): + def _create_test_app_and_site( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant + ): """ Helper method to create a test app and site for testing. @@ -188,10 +189,8 @@ class TestWebAppAuthService: enable_api=True, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create site site = Site( @@ -203,12 +202,12 @@ class TestWebAppAuthService: status="normal", customize_token_strategy="not_allow", ) - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() return app, site - def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful authentication with valid email and password. @@ -233,14 +232,15 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.password is not None assert result.password_salt is not None - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent email. @@ -262,7 +262,7 @@ class TestWebAppAuthService: with pytest.raises(AccountNotFoundError): WebAppAuthService.authenticate(non_existent_email, "any_password") - def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. @@ -292,10 +292,8 @@ class TestWebAppAuthService: account.password = base64.b64encode(password_hash).decode() account.password_salt = base64.b64encode(salt).decode() - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(AccountLoginError) as exc_info: @@ -303,7 +301,9 @@ class TestWebAppAuthService: assert "Account is banned." in str(exc_info.value) - def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_invalid_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invalid password. @@ -323,7 +323,7 @@ class TestWebAppAuthService: assert "Invalid email or password." in str(exc_info.value) def test_authenticate_account_without_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication for account without password. @@ -345,10 +345,8 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(AccountPasswordError) as exc_info: @@ -356,7 +354,7 @@ class TestWebAppAuthService: assert "Invalid email or password." in str(exc_info.value) - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login and JWT token generation. @@ -388,7 +386,9 @@ class TestWebAppAuthService: assert call_args["auth_type"] == "internal" assert "exp" in call_args - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful user retrieval through email. @@ -413,12 +413,13 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test user retrieval with non-existent email. @@ -435,7 +436,9 @@ class TestWebAppAuthService: # Assert: Verify proper handling assert result is None - def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_banned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test user retrieval with banned account. @@ -456,10 +459,8 @@ class TestWebAppAuthService: status=AccountStatus.BANNED, ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(Unauthorized) as exc_info: @@ -468,7 +469,7 @@ class TestWebAppAuthService: assert "Account is banned." in str(exc_info.value) def test_send_email_code_login_email_with_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email with account. @@ -509,7 +510,7 @@ class TestWebAppAuthService: assert "code" in mail_call_args[1] def test_send_email_code_login_email_with_email_only( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email with email only. @@ -549,7 +550,7 @@ class TestWebAppAuthService: assert "code" in mail_call_args[1] def test_send_email_code_login_email_no_email_provided( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email without providing email. @@ -566,7 +567,9 @@ class TestWebAppAuthService: assert "Email must be provided." in str(exc_info.value) - def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_email_code_login_data_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of email code login data. @@ -593,7 +596,9 @@ class TestWebAppAuthService: "mock_token", "email_code_login" ) - def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_email_code_login_data_no_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email code login data retrieval when no data exists. @@ -617,7 +622,7 @@ class TestWebAppAuthService: ) def test_revoke_email_code_login_token_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful revocation of email code login token. @@ -636,7 +641,7 @@ class TestWebAppAuthService: "mock_token", "email_code_login" ) - def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful end user creation. @@ -668,14 +673,15 @@ class TestWebAppAuthService: assert result.external_user_id == "enterpriseuser" # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None assert result.updated_at is not None - def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_site_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test end user creation with non-existent site code. @@ -693,7 +699,9 @@ class TestWebAppAuthService: assert "Site not found." in str(exc_info.value) - def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test end user creation when app is not found. @@ -708,10 +716,8 @@ class TestWebAppAuthService: status="normal", ) - from extensions.ext_database import db - - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() site = Site( app_id="00000000-0000-0000-0000-000000000000", @@ -722,8 +728,8 @@ class TestWebAppAuthService: status="normal", customize_token_strategy="not_allow", ) - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -732,7 +738,7 @@ class TestWebAppAuthService: assert "App not found." in str(exc_info.value) def test_is_app_require_permission_check_with_access_mode_private( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement for private access mode. @@ -751,7 +757,7 @@ class TestWebAppAuthService: assert result is True def test_is_app_require_permission_check_with_access_mode_public( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement for public access mode. @@ -770,7 +776,7 @@ class TestWebAppAuthService: assert result is False def test_is_app_require_permission_check_with_app_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement using app code. @@ -796,7 +802,7 @@ class TestWebAppAuthService: ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id") def test_is_app_require_permission_check_no_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement with no parameters. @@ -814,7 +820,7 @@ class TestWebAppAuthService: assert "Either app_code or app_id must be provided." in str(exc_info.value) def test_get_app_auth_type_with_access_mode_public( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test app authentication type for public access mode. @@ -833,7 +839,7 @@ class TestWebAppAuthService: assert result == WebAppAuthType.PUBLIC def test_get_app_auth_type_with_access_mode_private( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test app authentication type for private access mode. @@ -851,7 +857,9 @@ class TestWebAppAuthService: # Assert: Verify correct result assert result == WebAppAuthType.INTERNAL - def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_auth_type_with_app_code( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app authentication type using app code. @@ -878,7 +886,9 @@ class TestWebAppAuthService: "enterprise_service" ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id") - def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_auth_type_no_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app authentication type with no parameters. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index ce25eec6f0..a3440b6b67 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun @@ -48,7 +49,7 @@ class TestWorkflowAppService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -96,7 +97,7 @@ class TestWorkflowAppService: return app, account - def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test tenant and account for testing. @@ -126,7 +127,7 @@ class TestWorkflowAppService: return tenant, account - def _create_test_app(self, db_session_with_containers, tenant, account): + def _create_test_app(self, db_session_with_containers: Session, tenant, account): """ Helper method to create a test app for testing. @@ -160,7 +161,7 @@ class TestWorkflowAppService: return app - def _create_test_workflow_data(self, db_session_with_containers, app, account): + def _create_test_workflow_data(self, db_session_with_containers: Session, app, account): """ Helper method to create test workflow data for testing. @@ -174,8 +175,6 @@ class TestWorkflowAppService: """ fake = Faker() - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -188,8 +187,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run workflow_run = WorkflowRun( @@ -212,8 +211,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -227,13 +226,13 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() return workflow, workflow_run, workflow_app_log def test_get_paginate_workflow_app_logs_basic_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination of workflow app logs with basic parameters. @@ -268,13 +267,12 @@ class TestWorkflowAppService: assert log_entry.workflow_run_id == workflow_run.id # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow_app_log) + db_session_with_containers.refresh(workflow_app_log) assert workflow_app_log.id is not None def test_get_paginate_workflow_app_logs_with_keyword_search( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with keyword search functionality. @@ -287,11 +285,10 @@ class TestWorkflowAppService: ) # Update workflow run with searchable content - from extensions.ext_database import db workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"}) workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"}) - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword search service = WorkflowAppService() @@ -317,7 +314,7 @@ class TestWorkflowAppService: assert len(result_no_match["data"]) == 0 def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. @@ -332,8 +329,6 @@ class TestWorkflowAppService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) - from extensions.ext_database import db - service = WorkflowAppService() # Test 1: Search with % character @@ -353,8 +348,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_1) - db.session.flush() + db_session_with_containers.add(workflow_run_1) + db_session_with_containers.flush() workflow_app_log_1 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -367,8 +362,8 @@ class TestWorkflowAppService: ) workflow_app_log_1.id = str(uuid.uuid4()) workflow_app_log_1.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_1) - db.session.commit() + db_session_with_containers.add(workflow_app_log_1) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 @@ -395,8 +390,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_2) - db.session.flush() + db_session_with_containers.add(workflow_run_2) + db_session_with_containers.flush() workflow_app_log_2 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -409,8 +404,8 @@ class TestWorkflowAppService: ) workflow_app_log_2.id = str(uuid.uuid4()) workflow_app_log_2.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_2) - db.session.commit() + db_session_with_containers.add(workflow_app_log_2) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 @@ -437,8 +432,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_4) - db.session.flush() + db_session_with_containers.add(workflow_run_4) + db_session_with_containers.flush() workflow_app_log_4 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -451,8 +446,8 @@ class TestWorkflowAppService: ) workflow_app_log_4.id = str(uuid.uuid4()) workflow_app_log_4.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_4) - db.session.commit() + db_session_with_containers.add(workflow_app_log_4) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 @@ -467,7 +462,7 @@ class TestWorkflowAppService: assert workflow_run_4.id not in found_run_ids def test_get_paginate_workflow_app_logs_with_status_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with status filtering. @@ -476,8 +471,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -490,8 +483,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow runs with different statuses statuses = ["succeeded", "failed", "running", "stopped"] @@ -519,8 +512,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -533,8 +526,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -568,7 +561,7 @@ class TestWorkflowAppService: assert result_running["data"][0].workflow_run.status == "running" def test_get_paginate_workflow_app_logs_with_time_filtering( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with time-based filtering. @@ -577,8 +570,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -591,8 +582,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow runs with different timestamps base_time = datetime.now(UTC) @@ -627,8 +618,8 @@ class TestWorkflowAppService: created_at=timestamp, finished_at=timestamp + timedelta(minutes=1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -641,8 +632,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = timestamp - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -682,7 +673,7 @@ class TestWorkflowAppService: assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago def test_get_paginate_workflow_app_logs_with_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with different page sizes and limits. @@ -691,8 +682,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -705,8 +694,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create 25 workflow runs and logs total_logs = 25 @@ -734,8 +723,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -748,8 +737,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -798,7 +787,7 @@ class TestWorkflowAppService: assert len(result_large_limit["data"]) == total_logs def test_get_paginate_workflow_app_logs_with_user_role_filtering( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with user role and session filtering. @@ -807,8 +796,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -821,8 +808,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create end user end_user = EndUser( @@ -835,8 +822,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create workflow runs and logs for both account and end user workflow_runs = [] @@ -864,8 +851,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -878,8 +865,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -906,8 +893,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i + 10), finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -920,8 +907,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -994,7 +981,7 @@ class TestWorkflowAppService: assert "Account not found" in str(exc_info.value) def test_get_paginate_workflow_app_logs_with_uuid_keyword_search( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with UUID keyword search functionality. @@ -1003,8 +990,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -1017,8 +1002,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run with specific UUID workflow_run_id = str(uuid.uuid4()) @@ -1042,8 +1027,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC) + timedelta(minutes=1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -1057,8 +1042,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() # Act & Assert: Test UUID keyword search service = WorkflowAppService() @@ -1085,7 +1070,7 @@ class TestWorkflowAppService: assert result_invalid_uuid["total"] == 0 def test_get_paginate_workflow_app_logs_with_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with edge cases and boundary conditions. @@ -1094,8 +1079,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -1108,8 +1091,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run with edge case data workflow_run = WorkflowRun( @@ -1132,8 +1115,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -1147,8 +1130,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() # Act & Assert: Test edge cases service = WorkflowAppService() @@ -1185,7 +1168,7 @@ class TestWorkflowAppService: assert result_high_page["has_more"] is False def test_get_paginate_workflow_app_logs_with_empty_results( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with empty results and no data scenarios. @@ -1252,7 +1235,7 @@ class TestWorkflowAppService: assert "Account not found" in str(exc_info.value) def test_get_paginate_workflow_app_logs_with_complex_query_combinations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with complex query combinations. @@ -1352,7 +1335,7 @@ class TestWorkflowAppService: assert len(result_time_status_limit["data"]) <= 2 def test_get_paginate_workflow_app_logs_with_large_dataset_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with large dataset for performance validation. @@ -1444,7 +1427,7 @@ class TestWorkflowAppService: assert result_last_page["page"] == 3 def test_get_paginate_workflow_app_logs_with_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with proper tenant isolation. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 624251cd6c..ab409deb89 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,5 +1,6 @@ import pytest from faker import Faker +from sqlalchemy.orm import Session from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.segments import StringSegment @@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): """ Helper method to create a test app with realistic data for testing. @@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService: app.created_by = fake.uuid4() app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): """ Helper method to create a test workflow associated with an app. @@ -110,15 +109,14 @@ class TestWorkflowDraftVariableService: conversation_variables=[], rag_pipeline_variables=[], ) - from extensions.ext_database import db - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow def _create_test_variable( self, - db_session_with_containers, + db_session_with_containers: Session, app_id, node_id, name, @@ -174,13 +172,12 @@ class TestWorkflowDraftVariableService: visible=True, editable=True, ) - from extensions.ext_database import db - db.session.add(variable) - db.session.commit() + db_session_with_containers.add(variable) + db_session_with_containers.commit() return variable - def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting a single variable by ID successfully. @@ -202,7 +199,7 @@ class TestWorkflowDraftVariableService: assert retrieved_variable.app_id == app.id assert retrieved_variable.get_value().value == test_value.value - def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting a variable that doesn't exist. @@ -217,7 +214,7 @@ class TestWorkflowDraftVariableService: assert retrieved_variable is None def test_get_draft_variables_by_selectors_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting variables by selectors successfully. @@ -268,7 +265,7 @@ class TestWorkflowDraftVariableService: assert var.get_value().value == var3_value.value def test_list_variables_without_values_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test listing variables without values successfully with pagination. @@ -300,7 +297,7 @@ class TestWorkflowDraftVariableService: assert var.name is not None assert var.app_id == app.id - def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test listing variables for a specific node successfully. @@ -352,7 +349,9 @@ class TestWorkflowDraftVariableService: assert "var2" in var_names assert "var3" not in var_names - def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_conversation_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test listing conversation variables successfully. @@ -393,7 +392,7 @@ class TestWorkflowDraftVariableService: assert "conv_var2" in var_names assert "sys_var" not in var_names - def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating a variable's name and value successfully. @@ -418,14 +417,15 @@ class TestWorkflowDraftVariableService: assert updated_variable.name == "new_name" assert updated_variable.get_value().value == new_value.value assert updated_variable.last_edited_at is not None - from extensions.ext_database import db - db.session.refresh(variable) + db_session_with_containers.refresh(variable) assert variable.name == "new_name" assert variable.get_value().value == new_value.value assert variable.last_edited_at is not None - def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_variable_not_editable( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that updating a non-editable variable raises an exception. @@ -445,17 +445,18 @@ class TestWorkflowDraftVariableService: node_execution_id=fake.uuid4(), editable=False, # Set as non-editable ) - from extensions.ext_database import db - db.session.add(variable) - db.session.commit() + db_session_with_containers.add(variable) + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) with pytest.raises(UpdateNotSupportedError) as exc_info: service.update_variable(variable, name="new_name", value=new_value) assert "variable not support updating" in str(exc_info.value) assert variable.id in str(exc_info.value) - def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_reset_conversation_variable_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test resetting conversation variable successfully. @@ -476,9 +477,8 @@ class TestWorkflowDraftVariableService: selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], ) workflow.conversation_variables = [conv_var] - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() modified_value = StringSegment(value=fake.word()) variable = self._create_test_variable( db_session_with_containers, @@ -489,17 +489,17 @@ class TestWorkflowDraftVariableService: fake=fake, ) variable.last_edited_at = fake.date_time() - db.session.commit() + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) reset_variable = service.reset_variable(workflow, variable) assert reset_variable is not None assert reset_variable.get_value().value == "default_value" assert reset_variable.last_edited_at is None - db.session.refresh(variable) + db_session_with_containers.refresh(variable) assert variable.get_value().value == "default_value" assert variable.last_edited_at is None - def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test deleting a single variable successfully. @@ -513,14 +513,15 @@ class TestWorkflowDraftVariableService: variable = self._create_test_variable( db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake ) - from extensions.ext_database import db - assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None + assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None service = WorkflowDraftVariableService(db_session_with_containers) service.delete_variable(variable) - assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None + assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None - def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_workflow_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting all variables for a workflow successfully. @@ -550,20 +551,25 @@ class TestWorkflowDraftVariableService: other_value, fake=fake, ) - from extensions.ext_database import db - app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() - other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + ) assert len(app_variables) == 3 assert len(other_app_variables) == 1 service = WorkflowDraftVariableService(db_session_with_containers) service.delete_workflow_variables(app.id) - app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() - other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables_after = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + ) assert len(app_variables_after) == 0 assert len(other_app_variables_after) == 1 - def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_node_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting all variables for a specific node successfully. @@ -605,14 +611,15 @@ class TestWorkflowDraftVariableService: conv_value, fake=fake, ) - from extensions.ext_database import db - target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + target_node_variables = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + ) other_node_variables = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() ) conv_variables = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -622,13 +629,13 @@ class TestWorkflowDraftVariableService: service = WorkflowDraftVariableService(db_session_with_containers) service.delete_node_variables(app.id, node_id) target_node_variables_after = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() ) other_node_variables_after = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() ) conv_variables_after = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -637,7 +644,7 @@ class TestWorkflowDraftVariableService: assert len(conv_variables_after) == 1 def test_prefill_conversation_variable_default_values_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prefill conversation variable default values successfully. @@ -665,13 +672,12 @@ class TestWorkflowDraftVariableService: selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], ) workflow.conversation_variables = [conv_var1, conv_var2] - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) service.prefill_conversation_variable_default_values(workflow) draft_variables = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -686,7 +692,7 @@ class TestWorkflowDraftVariableService: assert var.get_variable_type() == DraftVariableType.CONVERSATION def test_get_conversation_id_from_draft_variable_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting conversation ID from draft variable successfully. @@ -713,7 +719,7 @@ class TestWorkflowDraftVariableService: assert retrieved_conv_id == conversation_id def test_get_conversation_id_from_draft_variable_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting conversation ID when it doesn't exist. @@ -728,7 +734,9 @@ class TestWorkflowDraftVariableService: retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) assert retrieved_conv_id is None - def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_system_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test listing system variables successfully. @@ -775,7 +783,9 @@ class TestWorkflowDraftVariableService: assert "sys_var2" in var_names assert "conv_var" not in var_names - def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_by_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting variables by name successfully for different types. @@ -822,7 +832,9 @@ class TestWorkflowDraftVariableService: assert retrieved_node_var.name == "test_node_var" assert retrieved_node_var.node_id == "test_node" - def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_by_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting variables by name when they don't exist. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 3a88081db3..38ef3975b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.enums import CreatorUserRole from models.model import ( @@ -48,7 +49,7 @@ class TestWorkflowRunService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -94,7 +95,7 @@ class TestWorkflowRunService: return app, account def _create_test_workflow_run( - self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0 + self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0 ): """ Helper method to create a test workflow run for testing. @@ -110,8 +111,6 @@ class TestWorkflowRunService: """ fake = Faker() - from extensions.ext_database import db - # Create workflow run with offset timestamp base_time = datetime.now(UTC) created_time = base_time - timedelta(minutes=offset_minutes) @@ -136,12 +135,12 @@ class TestWorkflowRunService: finished_at=created_time, ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() return workflow_run - def _create_test_message(self, db_session_with_containers, app, account, workflow_run): + def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run): """ Helper method to create a test message for testing. @@ -156,8 +155,6 @@ class TestWorkflowRunService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation first (required for message) from models.model import Conversation @@ -170,8 +167,8 @@ class TestWorkflowRunService: from_source=CreatorUserRole.ACCOUNT, from_account_id=account.id, ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message() @@ -193,12 +190,14 @@ class TestWorkflowRunService: message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_workflow_runs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination of workflow runs with debugging trigger. @@ -239,7 +238,7 @@ class TestWorkflowRunService: assert workflow_run.tenant_id == app.tenant_id def test_get_paginate_workflow_runs_with_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of workflow runs with last_id parameter. @@ -282,7 +281,7 @@ class TestWorkflowRunService: assert workflow_run.tenant_id == app.tenant_id def test_get_paginate_workflow_runs_default_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of workflow runs with default limit. @@ -320,7 +319,7 @@ class TestWorkflowRunService: assert workflow_run_result.tenant_id == app.tenant_id def test_get_paginate_advanced_chat_workflow_runs_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination of advanced chat workflow runs with message information. @@ -365,7 +364,7 @@ class TestWorkflowRunService: assert workflow_run.app_id == app.id assert workflow_run.tenant_id == app.tenant_id - def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of workflow run by ID. @@ -395,7 +394,7 @@ class TestWorkflowRunService: assert result.type == "chat" assert result.version == "1.0.0" - def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test workflow run retrieval when run ID does not exist. @@ -419,7 +418,7 @@ class TestWorkflowRunService: assert result is None def test_get_workflow_run_node_executions_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of workflow run node executions. @@ -438,7 +437,6 @@ class TestWorkflowRunService: workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") # Create node executions - from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionModel node_executions = [] @@ -462,7 +460,7 @@ class TestWorkflowRunService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(node_execution) + db_session_with_containers.add(node_execution) node_executions.append(node_execution) paused_node_execution = WorkflowNodeExecutionModel( @@ -484,9 +482,9 @@ class TestWorkflowRunService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(paused_node_execution) + db_session_with_containers.add(paused_node_execution) - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test workflow_run_service = WorkflowRunService() @@ -509,7 +507,7 @@ class TestWorkflowRunService: assert node_execution.node_id.startswith("node_") def test_get_workflow_run_node_executions_empty( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions for a workflow run with no executions. @@ -560,7 +558,7 @@ class TestWorkflowRunService: assert len(result) == 0 def test_get_workflow_run_node_executions_invalid_workflow_run_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions with invalid workflow run ID. @@ -611,7 +609,7 @@ class TestWorkflowRunService: assert len(result) == 0 def test_get_workflow_run_node_executions_database_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions when database encounters an error. @@ -662,7 +660,7 @@ class TestWorkflowRunService: ) def test_get_workflow_run_node_executions_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test node execution retrieval for end user. @@ -680,7 +678,6 @@ class TestWorkflowRunService: workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") # Create end user - from extensions.ext_database import db from models.model import EndUser end_user = EndUser( @@ -692,8 +689,8 @@ class TestWorkflowRunService: external_user_id=str(uuid.uuid4()), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create node execution from models.workflow import WorkflowNodeExecutionModel @@ -717,8 +714,8 @@ class TestWorkflowRunService: created_by=end_user.id, created_at=datetime.now(UTC), ) - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() # Act: Execute the method under test workflow_run_service = WorkflowRunService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index ef575a9b69..bfb23bac68 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -10,6 +10,7 @@ from unittest.mock import MagicMock import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, App, Workflow from models.model import AppMode @@ -32,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -67,18 +68,16 @@ class TestWorkflowService: tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at - from extensions.ext_database import db - - db.session.add(tenant) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.add(account) + db_session_with_containers.commit() # Set the current tenant for the account account.current_tenant = tenant return account - def _create_test_app(self, db_session_with_containers, fake=None): + def _create_test_app(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test app with realistic data. @@ -106,13 +105,11 @@ class TestWorkflowService: ) app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): """ Helper method to create a test workflow associated with an app. @@ -141,13 +138,11 @@ class TestWorkflowService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_get_node_last_run_success(self, db_session_with_containers): + def test_get_node_last_run_success(self, db_session_with_containers: Session): """ Test successful retrieval of the most recent execution for a specific node. @@ -180,10 +175,8 @@ class TestWorkflowService: node_execution.created_by = account.id # Required field node_execution.created_at = fake.date_time_this_year() - from extensions.ext_database import db - - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -196,7 +189,7 @@ class TestWorkflowService: assert result.workflow_id == workflow.id assert result.status == "succeeded" - def test_get_node_last_run_not_found(self, db_session_with_containers): + def test_get_node_last_run_not_found(self, db_session_with_containers: Session): """ Test retrieval when no execution record exists for the specified node. @@ -217,7 +210,7 @@ class TestWorkflowService: # Assert assert result is None - def test_is_workflow_exist_true(self, db_session_with_containers): + def test_is_workflow_exist_true(self, db_session_with_containers: Session): """ Test workflow existence check when a draft workflow exists. @@ -238,7 +231,7 @@ class TestWorkflowService: # Assert assert result is True - def test_is_workflow_exist_false(self, db_session_with_containers): + def test_is_workflow_exist_false(self, db_session_with_containers: Session): """ Test workflow existence check when no draft workflow exists. @@ -258,7 +251,7 @@ class TestWorkflowService: # Assert assert result is False - def test_get_draft_workflow_success(self, db_session_with_containers): + def test_get_draft_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of a draft workflow. @@ -284,7 +277,7 @@ class TestWorkflowService: assert result.app_id == app.id assert result.tenant_id == app.tenant_id - def test_get_draft_workflow_not_found(self, db_session_with_containers): + def test_get_draft_workflow_not_found(self, db_session_with_containers: Session): """ Test draft workflow retrieval when no draft workflow exists. @@ -304,7 +297,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_by_id_success(self, db_session_with_containers): + def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session): """ Test successful retrieval of a published workflow by ID. @@ -321,9 +314,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -336,7 +327,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers): + def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session): """ Test error when trying to retrieve a draft workflow as published. @@ -359,7 +350,7 @@ class TestWorkflowService: with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow.id) - def test_get_published_workflow_by_id_not_found(self, db_session_with_containers): + def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session): """ Test retrieval when no workflow exists with the specified ID. @@ -379,7 +370,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_success(self, db_session_with_containers): + def test_get_published_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of the current published workflow for an app. @@ -395,10 +386,8 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -411,7 +400,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_no_workflow_id(self, db_session_with_containers): + def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session): """ Test retrieval when app has no associated workflow ID. @@ -431,7 +420,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_all_published_workflow_pagination(self, db_session_with_containers): + def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session): """ Test pagination of published workflows. @@ -455,15 +444,13 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflows[0].id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - First page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=1, limit=3, @@ -476,7 +463,7 @@ class TestWorkflowService: # Act - Second page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=2, limit=3, @@ -487,7 +474,7 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert has_more is False - def test_get_all_published_workflow_user_filter(self, db_session_with_containers): + def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session): """ Test filtering published workflows by user. @@ -513,22 +500,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter by account1 result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=account1.id + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id ) # Assert assert len(result_workflows) == 1 assert result_workflows[0].created_by == account1.id - def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers): + def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session): """ Test filtering published workflows to show only named workflows. @@ -557,22 +542,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter named only result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True ) # Assert assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) - def test_sync_draft_workflow_create_new(self, db_session_with_containers): + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. @@ -624,7 +607,7 @@ class TestWorkflowService: assert result.features == json.dumps(features) assert result.created_by == account.id - def test_sync_draft_workflow_update_existing(self, db_session_with_containers): + def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session): """ Test updating an existing draft workflow through sync operation. @@ -688,7 +671,7 @@ class TestWorkflowService: assert result.features == json.dumps(new_features) assert result.updated_by == account.id - def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers): + def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session): """ Test error when sync is attempted with mismatched hash. @@ -738,7 +721,7 @@ class TestWorkflowService: conversation_variables=conversation_variables, ) - def test_publish_workflow_success(self, db_session_with_containers): + def test_publish_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow publishing. @@ -755,9 +738,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = Workflow.VERSION_DRAFT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -777,7 +758,7 @@ class TestWorkflowService: assert len(result.version) > 10 # Should be a reasonable timestamp length assert result.created_by == account.id - def test_publish_workflow_no_draft_error(self, db_session_with_containers): + def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session): """ Test error when publishing workflow without draft. @@ -797,7 +778,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_publish_workflow_already_published_error(self, db_session_with_containers): + def test_publish_workflow_already_published_error(self, db_session_with_containers: Session): """ Test error when publishing already published workflow. @@ -813,9 +794,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Already published - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -823,7 +802,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_get_default_block_configs(self, db_session_with_containers): + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. @@ -847,7 +826,7 @@ class TestWorkflowService: assert isinstance(config, dict) # The structure can vary, so we just check it's a dict - def test_get_default_block_config_specific_type(self, db_session_with_containers): + def test_get_default_block_config_specific_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for a specific node type. @@ -867,7 +846,7 @@ class TestWorkflowService: # This is acceptable behavior assert result is None or isinstance(result, dict) - def test_get_default_block_config_invalid_type(self, db_session_with_containers): + def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for invalid node type. @@ -887,7 +866,7 @@ class TestWorkflowService: # It's also acceptable for the service to raise a ValueError for invalid types pass - def test_get_default_block_config_with_filters(self, db_session_with_containers): + def test_get_default_block_config_with_filters(self, db_session_with_containers: Session): """ Test retrieval of default block configuration with filters. @@ -907,7 +886,7 @@ class TestWorkflowService: # Result might be None if filters don't match, but should not raise error assert result is None or isinstance(result, dict) - def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from chat mode app to workflow mode. @@ -944,11 +923,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -969,7 +946,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from completion mode app to workflow mode. @@ -1006,11 +983,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -1031,7 +1006,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers): + def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session): """ Test error when attempting to convert unsupported app mode. @@ -1046,9 +1021,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = {"name": "Test"} @@ -1057,7 +1030,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"): workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args) - def test_validate_features_structure_advanced_chat(self, db_session_with_containers): + def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session): """ Test feature structure validation for advanced chat mode apps. @@ -1069,9 +1042,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.ADVANCED_CHAT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = { @@ -1088,7 +1059,7 @@ class TestWorkflowService: # The exact behavior depends on the AdvancedChatAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_workflow(self, db_session_with_containers): + def test_validate_features_structure_workflow(self, db_session_with_containers: Session): """ Test feature structure validation for workflow mode apps. @@ -1100,9 +1071,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"workflow_config": {"max_steps": 10, "timeout": 300}} @@ -1115,7 +1084,7 @@ class TestWorkflowService: # The exact behavior depends on the WorkflowAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_invalid_mode(self, db_session_with_containers): + def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session): """ Test error when validating features for invalid app mode. @@ -1127,9 +1096,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"test": "value"} @@ -1138,7 +1105,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): workflow_service.validate_features_structure(app_model=app, features=features) - def test_update_workflow_success(self, db_session_with_containers): + def test_update_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow update with allowed fields. @@ -1152,16 +1119,14 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1174,7 +1139,7 @@ class TestWorkflowService: assert result.marked_comment == update_data["marked_comment"] assert result.updated_by == account.id - def test_update_workflow_not_found(self, db_session_with_containers): + def test_update_workflow_not_found(self, db_session_with_containers: Session): """ Test workflow update when workflow doesn't exist. @@ -1186,15 +1151,13 @@ class TestWorkflowService: account = self._create_test_account(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake) - from extensions.ext_database import db - workflow_service = WorkflowService() non_existent_workflow_id = fake.uuid4() update_data = {"marked_name": "Test"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id, account_id=account.id, @@ -1204,7 +1167,7 @@ class TestWorkflowService: # Assert assert result is None - def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers): + def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session): """ Test that workflow update ignores disallowed fields. @@ -1218,9 +1181,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) original_name = workflow.marked_name - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = { @@ -1231,7 +1192,7 @@ class TestWorkflowService: # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1245,7 +1206,7 @@ class TestWorkflowService: assert result.graph == workflow.graph assert result.features == workflow.features - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow deletion. @@ -1262,25 +1223,23 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act result = workflow_service.delete_workflow( - session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id ) # Assert assert result is True # Verify workflow is actually deleted - deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first() + deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first() assert deleted_workflow is None - def test_delete_workflow_draft_error(self, db_session_with_containers): + def test_delete_workflow_draft_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a draft workflow. @@ -1296,9 +1255,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) # Keep as draft version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1306,9 +1263,11 @@ class TestWorkflowService: from services.errors.workflow_service import DraftWorkflowDeletionError with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_in_use_error(self, db_session_with_containers): + def test_delete_workflow_in_use_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a workflow that's in use by an app. @@ -1327,9 +1286,7 @@ class TestWorkflowService: # Associate workflow with app app.workflow_id = workflow.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1337,9 +1294,11 @@ class TestWorkflowService: from services.errors.workflow_service import WorkflowInUseError with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_not_found_error(self, db_session_with_containers): + def test_delete_workflow_not_found_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a non-existent workflow. @@ -1351,17 +1310,15 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) non_existent_workflow_id = fake.uuid4() - from extensions.ext_database import db - workflow_service = WorkflowService() # Act & Assert with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"): workflow_service.delete_workflow( - session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id ) - def test_run_free_workflow_node_success(self, db_session_with_containers): + def test_run_free_workflow_node_success(self, db_session_with_containers: Session): """ Test successful execution of a free workflow node. @@ -1413,7 +1370,7 @@ class TestWorkflowService: assert result.workflow_id == "" # No workflow ID for free nodes assert result.index == 1 - def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers): + def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session): """ Test execution of a free workflow node with complex input data. @@ -1454,7 +1411,7 @@ class TestWorkflowService: error_msg = str(exc_info.value).lower() assert any(keyword in error_msg for keyword in ["start", "not supported", "external"]) - def test_handle_node_run_result_success(self, db_session_with_containers): + def test_handle_node_run_result_success(self, db_session_with_containers: Session): """ Test successful handling of node run results. @@ -1529,7 +1486,7 @@ class TestWorkflowService: assert result.outputs is not None assert result.process_data is not None - def test_handle_node_run_result_failure(self, db_session_with_containers): + def test_handle_node_run_result_failure(self, db_session_with_containers: Session): """ Test handling of failed node run results. @@ -1598,7 +1555,7 @@ class TestWorkflowService: assert result.error is not None assert "Test error message" in str(result.error) - def test_handle_node_run_result_continue_on_error(self, db_session_with_containers): + def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session): """ Test handling of node run results with continue_on_error strategy. diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 4249642bc9..92dec24c7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.workspace_service import WorkspaceService @@ -29,7 +30,7 @@ class TestWorkspaceService: "dify_config": mock_dify_config, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,10 +51,8 @@ class TestWorkspaceService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +61,8 @@ class TestWorkspaceService: plan="basic", custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}', ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +71,15 @@ class TestWorkspaceService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tenant information with all features enabled. @@ -121,13 +120,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_without_custom_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval when custom config features are disabled. @@ -167,13 +165,12 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_normal_user_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for normal user role without privileged features. @@ -191,11 +188,14 @@ class TestWorkspaceService: ) # Update the join to have normal role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.NORMAL - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -220,11 +220,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_admin_role_and_logo_replacement( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for admin role with logo replacement enabled. @@ -242,11 +242,14 @@ class TestWorkspaceService: ) # Update the join to have admin role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.ADMIN - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -268,10 +271,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None - def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_with_tenant_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant info retrieval when tenant parameter is None. @@ -290,7 +295,7 @@ class TestWorkspaceService: assert result is None def test_get_tenant_info_with_custom_config_variations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with various custom config configurations. @@ -323,10 +328,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -353,11 +356,11 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_editor_role_and_limited_permissions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for editor role with limited permissions. @@ -375,11 +378,14 @@ class TestWorkspaceService: ) # Update the join to have editor role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.EDITOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -400,11 +406,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_dataset_operator_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for dataset operator role. @@ -422,11 +428,14 @@ class TestWorkspaceService: ) # Update the join to have dataset operator role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.DATASET_OPERATOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -447,11 +456,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_complex_custom_config_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with complex custom config scenarios. @@ -491,10 +500,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -525,5 +532,5 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] is False # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 2ff71ea6ea..bffdca623a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import TypeAdapter, ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant @@ -34,7 +35,7 @@ class TestApiToolManageService: "provider_controller": mock_provider_controller, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -55,18 +56,16 @@ class TestApiToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -77,8 +76,8 @@ class TestApiToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -118,7 +117,7 @@ class TestApiToolManageService: """ def test_parser_api_schema_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful parsing of API schema. @@ -163,7 +162,7 @@ class TestApiToolManageService: assert api_key_value_field["default"] == "" def test_parser_api_schema_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of invalid API schema. @@ -183,7 +182,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_parser_api_schema_malformed_json( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of malformed JSON schema. @@ -203,7 +202,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_convert_schema_to_tool_bundles_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles. @@ -233,7 +232,7 @@ class TestApiToolManageService: assert tool_bundle.operation_id == "testOperation" def test_convert_schema_to_tool_bundles_with_extra_info( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles with extra info. @@ -259,7 +258,7 @@ class TestApiToolManageService: assert isinstance(schema_type, str) def test_convert_schema_to_tool_bundles_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of invalid schema to tool bundles. @@ -279,7 +278,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_create_api_tool_provider_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider. @@ -324,10 +323,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) @@ -347,7 +345,7 @@ class TestApiToolManageService: mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once() def test_create_api_tool_provider_duplicate_name( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with duplicate name. @@ -404,7 +402,7 @@ class TestApiToolManageService: assert f"provider {provider_name} already exists" in str(exc_info.value) def test_create_api_tool_provider_invalid_schema_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with invalid schema type. @@ -436,7 +434,7 @@ class TestApiToolManageService: assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with missing auth type. @@ -479,7 +477,7 @@ class TestApiToolManageService: assert "auth_type is required" in str(exc_info.value) def test_create_api_tool_provider_with_api_key_auth( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider with API key authentication. @@ -522,10 +520,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 6cae83ac37..0f2e3980af 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ToolProviderType from models import Account, Tenant @@ -41,7 +42,7 @@ class TestMCPToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -62,18 +63,16 @@ class TestMCPToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -84,8 +83,8 @@ class TestMCPToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -93,7 +92,7 @@ class TestMCPToolManageService: return account, tenant def _create_test_mcp_provider( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id ): """ Helper method to create a test MCP tool provider for testing. @@ -124,15 +123,13 @@ class TestMCPToolManageService: sse_read_timeout=300.0, ) - from extensions.ext_database import db - - db.session.add(mcp_provider) - db.session.commit() + db_session_with_containers.add(mcp_provider) + db_session_with_containers.commit() return mcp_provider def test_get_mcp_provider_by_provider_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by provider ID. @@ -153,9 +150,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -166,12 +162,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.server_identifier == mcp_provider.server_identifier def test_get_mcp_provider_by_provider_id_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by provider ID. @@ -190,14 +186,13 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id) def test_get_mcp_provider_by_provider_id_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by provider ID. @@ -223,14 +218,13 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id) def test_get_mcp_provider_by_server_identifier_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by server identifier. @@ -251,9 +245,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -264,12 +257,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.name == mcp_provider.name def test_get_mcp_provider_by_server_identifier_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by server identifier. @@ -288,14 +281,13 @@ class TestMCPToolManageService: non_existent_identifier = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id) def test_get_mcp_provider_by_server_identifier_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by server identifier. @@ -321,13 +313,12 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id) - def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of MCP provider. @@ -365,9 +356,8 @@ class TestMCPToolManageService: # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -389,10 +379,9 @@ class TestMCPToolManageService: assert result.type == ToolProviderType.MCP # Verify database state - from extensions.ext_database import db created_provider = ( - db.session.query(MCPToolProvider) + db_session_with_containers.query(MCPToolProvider) .filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider") .first() ) @@ -410,7 +399,9 @@ class TestMCPToolManageService: ) mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once() - def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when creating MCP provider with duplicate name. @@ -427,9 +418,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -463,7 +453,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server URL. @@ -481,9 +471,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -517,7 +506,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_identifier( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server identifier. @@ -535,9 +524,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -570,7 +558,7 @@ class TestMCPToolManageService: ), ) - def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of MCP tools for a tenant. @@ -602,9 +590,7 @@ class TestMCPToolManageService: ) provider3.name = "Gamma Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Setup mock for transformation service from core.tools.entities.api_entities import ToolProviderApiEntity @@ -647,9 +633,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=True) # Assert: Verify the expected outcomes @@ -666,7 +651,9 @@ class TestMCPToolManageService: mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3 ) - def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of MCP tools when tenant has no providers. @@ -684,9 +671,8 @@ class TestMCPToolManageService: # No MCP providers created for this tenant # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=False) # Assert: Verify the expected outcomes @@ -697,7 +683,9 @@ class TestMCPToolManageService: # Verify no transformation service calls for empty list mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called() - def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when retrieving MCP tools. @@ -756,9 +744,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test for both tenants - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result1 = service.list_providers(tenant_id=tenant1.id, for_list=True) result2 = service.list_providers(tenant_id=tenant2.id, for_list=True) @@ -769,7 +756,7 @@ class TestMCPToolManageService: assert result2[0].id == provider2.id def test_list_mcp_tool_from_remote_server_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful listing of MCP tools from remote server. @@ -797,9 +784,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to list tools mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -821,9 +806,8 @@ class TestMCPToolManageService: mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes @@ -834,7 +818,7 @@ class TestMCPToolManageService: # Note: server_url is mocked, so we skip that assertion to avoid encryption issues # Verify database state was updated - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.tools != "[]" assert mcp_provider.updated_at is not None @@ -844,7 +828,7 @@ class TestMCPToolManageService: mock_mcp_client.assert_called_once() def test_list_mcp_tool_from_remote_server_auth_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server requires authentication. @@ -871,9 +855,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -887,19 +869,18 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Please auth the tool first"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" def test_list_mcp_tool_from_remote_server_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server connection fails. @@ -926,9 +907,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to test connection errors mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -942,18 +921,17 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True # Provider remains authenticated assert mcp_provider.tools == "[]" - def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of MCP tool. @@ -974,20 +952,19 @@ class TestMCPToolManageService: ) # Verify provider exists - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None # Act: Execute the method under test - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes # Provider should be deleted from database - deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() + deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() assert deleted_provider is None - def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when deleting non-existent MCP tool. @@ -1005,13 +982,14 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id) - def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when deleting MCP tool. @@ -1036,18 +1014,16 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id) # Verify provider still exists in tenant1 - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None - def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful update of MCP provider. @@ -1070,14 +1046,12 @@ class TestMCPToolManageService: original_name = mcp_provider.name original_icon = mcp_provider.icon - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider( tenant_id=tenant.id, provider_id=mcp_provider.id, @@ -1094,7 +1068,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.name == "Updated MCP Provider" assert mcp_provider.server_identifier == "updated_identifier_123" assert mcp_provider.timeout == 45.0 @@ -1108,7 +1082,9 @@ class TestMCPToolManageService: assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when updating MCP provider with duplicate name. @@ -1134,15 +1110,12 @@ class TestMCPToolManageService: ) provider2.name = "Second Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Verify proper error handling for duplicate name from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool First Provider already exists"): service.update_provider( tenant_id=tenant.id, @@ -1160,7 +1133,7 @@ class TestMCPToolManageService: ) def test_update_mcp_provider_credentials_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of MCP provider credentials. @@ -1185,9 +1158,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1202,9 +1173,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1213,7 +1183,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.updated_at is not None @@ -1225,7 +1195,7 @@ class TestMCPToolManageService: assert "new_key" in credentials def test_update_mcp_provider_credentials_not_authed( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of MCP provider credentials when not authenticated. @@ -1249,9 +1219,7 @@ class TestMCPToolManageService: mcp_provider.authed = True mcp_provider.tools = '[{"name": "test_tool"}]' - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1266,9 +1234,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1277,12 +1244,14 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" assert mcp_provider.updated_at is not None - def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful reconnection to MCP provider. @@ -1343,7 +1312,9 @@ class TestMCPToolManageService: sse_read_timeout=mcp_provider.sse_read_timeout, ) - def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_auth_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test reconnection to MCP provider when authentication fails. @@ -1385,7 +1356,7 @@ class TestMCPToolManageService: assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test reconnection to MCP provider when connection fails. diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index fa13790942..f3736333ea 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -27,7 +28,7 @@ class TestToolTransformService: } def _create_test_tool_provider( - self, db_session_with_containers, mock_external_service_dependencies, provider_type="api" + self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api" ): """ Helper method to create a test tool provider for testing. @@ -89,14 +90,12 @@ class TestToolTransformService: else: raise ValueError(f"Unknown provider type: {provider_type}") - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider - def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful plugin icon URL generation. @@ -126,7 +125,7 @@ class TestToolTransformService: assert result == expected_url def test_get_plugin_icon_url_with_empty_console_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test plugin icon URL generation when CONSOLE_API_URL is empty. @@ -156,7 +155,7 @@ class TestToolTransformService: assert result == expected_url def test_get_tool_provider_icon_url_builtin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for builtin providers. @@ -194,7 +193,7 @@ class TestToolTransformService: assert result == expected_encoded def test_get_tool_provider_icon_url_api_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for API providers. @@ -220,7 +219,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_api_invalid_json( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for API providers with invalid JSON. @@ -246,7 +245,7 @@ class TestToolTransformService: assert result["content"] == "😁" or result["content"] == "\ud83d\ude01" def test_get_tool_provider_icon_url_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for workflow providers. @@ -271,7 +270,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_mcp_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for MCP providers. @@ -296,7 +295,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_unknown_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for unknown provider types. @@ -317,7 +316,9 @@ class TestToolTransformService: # Assert: Verify the expected outcomes assert result == "" - def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_dict_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with dictionary input. @@ -341,7 +342,9 @@ class TestToolTransformService: # Note: provider name may contain spaces that get URL encoded assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"] - def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with ToolProviderApiEntity input. @@ -389,7 +392,7 @@ class TestToolTransformService: assert "test_icon_dark.png" in provider.icon_dark def test_repack_provider_entity_no_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful provider repacking with ToolProviderApiEntity input without plugin_id. @@ -435,7 +438,9 @@ class TestToolTransformService: assert provider.icon_dark["background"] == "#252525" assert provider.icon_dark["content"] == "🔧" - def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_no_dark_icon( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test provider repacking with ToolProviderApiEntity input without dark icon. @@ -477,7 +482,7 @@ class TestToolTransformService: assert provider.icon_dark == "" def test_builtin_provider_to_user_provider_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider. @@ -545,7 +550,7 @@ class TestToolTransformService: assert result.original_credentials == {"api_key": "decrypted_key"} def test_builtin_provider_to_user_provider_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider with plugin. @@ -589,7 +594,7 @@ class TestToolTransformService: assert result.allow_delete is False def test_builtin_provider_to_user_provider_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of builtin provider to user provider without credentials. @@ -630,7 +635,9 @@ class TestToolTransformService: assert result.allow_delete is False assert result.masked_credentials == {"api_key": ""} - def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_api_provider_to_controller_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion of API provider to controller. @@ -655,10 +662,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -669,7 +674,7 @@ class TestToolTransformService: # Additional assertions would depend on the actual controller implementation def test_api_provider_to_controller_api_key_query( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with api_key_query auth type. @@ -693,10 +698,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -706,7 +709,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_api_provider_to_controller_backward_compatibility( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with backward compatibility auth types. @@ -731,10 +734,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -744,7 +745,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_workflow_provider_to_controller_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of workflow provider to controller. @@ -769,10 +770,8 @@ class TestToolTransformService: parameter_configuration="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Mock the WorkflowToolProviderController.from_db method to avoid app dependency with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db: diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 24fe5c4670..0b3c1112bd 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError @@ -63,7 +64,7 @@ class TestWorkflowToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -119,14 +120,12 @@ class TestWorkflowToolManageService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Update app to reference the workflow app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() return app, account, workflow @@ -153,7 +152,9 @@ class TestWorkflowToolManageService: ), ] - def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool creation with valid parameters. @@ -198,11 +199,10 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db # Check if workflow tool provider was created created_tool_provider = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -230,7 +230,7 @@ class TestWorkflowToolManageService: ].workflow_provider_to_controller.assert_called_once() def test_create_workflow_tool_duplicate_name_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when name already exists. @@ -280,10 +280,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -293,7 +292,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_invalid_app_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app does not exist. @@ -331,10 +330,9 @@ class TestWorkflowToolManageService: assert f"App {non_existent_app_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -344,7 +342,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_invalid_parameters_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when parameters are invalid. @@ -387,10 +385,9 @@ class TestWorkflowToolManageService: assert "validation error" in str(exc_info.value).lower() # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -400,7 +397,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_duplicate_app_id_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app_id already exists. @@ -450,10 +447,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -463,7 +459,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app has no workflow. @@ -481,10 +477,9 @@ class TestWorkflowToolManageService: ) # Remove workflow reference from app - from extensions.ext_database import db app.workflow_id = None - db.session.commit() + db_session_with_containers.commit() # Attempt to create workflow tool for app without workflow tool_parameters = self._create_test_workflow_tool_parameters() @@ -505,7 +500,7 @@ class TestWorkflowToolManageService: # Verify no workflow tool was created tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -515,7 +510,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when workflow contains human input nodes. @@ -558,10 +553,8 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - from extensions.ext_database import db - tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -570,7 +563,9 @@ class TestWorkflowToolManageService: assert tool_count == 0 - def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool update with valid parameters. @@ -603,10 +598,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -641,7 +635,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state was updated - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label @@ -658,7 +652,7 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() def test_update_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update fails when workflow contains human input nodes. @@ -689,10 +683,8 @@ class TestWorkflowToolManageService: parameters=initial_tool_parameters, ) - from extensions.ext_database import db - created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -712,7 +704,7 @@ class TestWorkflowToolManageService: ] } ) - db.session.commit() + db_session_with_containers.commit() with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: WorkflowToolManageService.update_workflow_tool( @@ -728,10 +720,12 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == original_name - def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_not_found_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test workflow tool update fails when tool does not exist. @@ -768,10 +762,9 @@ class TestWorkflowToolManageService: assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -781,7 +774,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_update_workflow_tool_same_name_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update succeeds when keeping the same name. @@ -813,10 +806,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -840,12 +832,12 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify tool still exists with the same name - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == first_tool_name assert created_tool.updated_at is not None def test_create_workflow_tool_with_file_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILE parameter having a file object as default. @@ -916,7 +908,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_with_files_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILES (Array[File]) parameter having file objects as default. @@ -991,7 +983,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_db_commit_before_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that database commit happens before validation, causing DB pollution on validation failure. @@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService: # Verify the tool was NOT created in database # This is the expected behavior (no pollution) - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name, diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 0c2ccaa051..8c007877fd 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -79,7 +80,7 @@ class TestWorkflowConverter: mock_config.app_model_config_dict = {} return mock_config - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -100,18 +101,16 @@ class TestWorkflowConverter: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -122,15 +121,17 @@ class TestWorkflowConverter: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account + ): """ Helper method to create a test app for testing. @@ -163,10 +164,8 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -177,16 +176,16 @@ class TestWorkflowConverter: created_by=account.id, updated_by=account.id, ) - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Link app model config to app app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() return app - def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful conversion of app to workflow. @@ -225,19 +224,18 @@ class TestWorkflowConverter: assert new_app.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(new_app) + db_session_with_containers.refresh(new_app) assert new_app.id is not None # Verify workflow was created - workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first() + workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first() assert workflow is not None assert workflow.tenant_id == app.tenant_id assert workflow.type == "chat" def test_convert_to_workflow_without_app_model_config_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is missing. @@ -270,16 +268,14 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling workflow_converter = WorkflowConverter() # Check initial state - initial_workflow_count = db.session.query(Workflow).count() + initial_workflow_count = db_session_with_containers.query(Workflow).count() with pytest.raises(ValueError, match="App model config is required"): workflow_converter.convert_to_workflow( @@ -294,12 +290,12 @@ class TestWorkflowConverter: # Verify database state remains unchanged # The workflow creation happens in convert_app_model_config_to_workflow # which is called before the app_model_config check, so we need to clean up - db.session.rollback() - final_workflow_count = db.session.query(Workflow).count() + db_session_with_containers.rollback() + final_workflow_count = db_session_with_containers.query(Workflow).count() assert final_workflow_count == initial_workflow_count def test_convert_app_model_config_to_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of app model config to workflow. @@ -356,16 +352,17 @@ class TestWorkflowConverter: assert answer_node["id"] == "answer" # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow) + db_session_with_containers.refresh(workflow) assert workflow.id is not None # Verify features were set features = json.loads(workflow._features) if workflow._features else {} assert isinstance(features, dict) - def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_start_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to start node. @@ -410,7 +407,9 @@ class TestWorkflowConverter: assert second_variable["label"] == "Number Input" assert second_variable["type"] == "number" - def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_http_request_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to HTTP request node. @@ -436,10 +435,8 @@ class TestWorkflowConverter: api_endpoint="https://api.example.com/test", ) - from extensions.ext_database import db - - db.session.add(api_based_extension) - db.session.commit() + db_session_with_containers.add(api_based_extension) + db_session_with_containers.commit() # Mock encrypter mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key" @@ -489,7 +486,7 @@ class TestWorkflowConverter: assert external_data_variable_node_mapping["external_data"] == code_node["id"] def test_convert_to_knowledge_retrieval_node_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion to knowledge retrieval node. diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 8bb536c34a..efeb29cf20 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment @@ -31,7 +31,9 @@ class TestAddDocumentToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +53,15 @@ class TestAddDocumentToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +70,8 @@ class TestAddDocumentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -81,8 +83,8 @@ class TestAddDocumentToIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -99,15 +101,15 @@ class TestAddDocumentToIndexTask: enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document - def _create_test_segments(self, db_session_with_containers, document, dataset): + def _create_test_segments(self, db_session_with_containers: Session, document, dataset): """ Helper method to create test document segments. @@ -138,13 +140,15 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments - def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_document_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with paragraph index type. @@ -180,9 +184,9 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -191,7 +195,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with different index types. @@ -209,10 +213,10 @@ class TestAddDocumentToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -237,9 +241,9 @@ class TestAddDocumentToIndexTask: assert len(documents) == 3 # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -248,7 +252,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -275,7 +279,7 @@ class TestAddDocumentToIndexTask: # because indexing_cache_key is not defined in that case def test_add_document_to_index_invalid_indexing_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid indexing status. @@ -294,7 +298,7 @@ class TestAddDocumentToIndexTask: # Set invalid indexing status document.indexing_status = "processing" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) @@ -304,7 +308,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_add_document_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when document's dataset doesn't exist. @@ -326,14 +330,14 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Delete the dataset to simulate dataset not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False assert document.indexing_status == "error" assert document.error is not None @@ -348,7 +352,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with parent-child structure. @@ -367,10 +371,10 @@ class TestAddDocumentToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -413,9 +417,9 @@ class TestAddDocumentToIndexTask: assert len(doc.children) == 2 # Each document has 2 children # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -424,7 +428,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_already_enabled_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing when segments are already enabled. @@ -459,10 +463,10 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -488,7 +492,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_auto_disable_log_deletion( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that auto disable logs are properly deleted during indexing. @@ -515,10 +519,10 @@ class TestAddDocumentToIndexTask: document_id=document.id, ) log_entry.id = str(fake.uuid4()) - db.session.add(log_entry) + db_session_with_containers.add(log_entry) auto_disable_logs.append(log_entry) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -526,7 +530,9 @@ class TestAddDocumentToIndexTask: # Verify logs exist before processing existing_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(existing_logs) == 2 @@ -535,7 +541,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify auto disable logs were deleted remaining_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(remaining_logs) == 0 @@ -547,14 +555,14 @@ class TestAddDocumentToIndexTask: # Verify segments were enabled for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify redis cache was cleared assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -584,7 +592,7 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False assert document.indexing_status == "error" assert document.error is not None @@ -593,14 +601,14 @@ class TestAddDocumentToIndexTask: # Verify segments were not enabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False # Should remain disabled due to error # Verify redis cache was still cleared despite error assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_segment_filtering_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment filtering with various edge cases. @@ -638,7 +646,7 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment1) + db_session_with_containers.add(segment1) segments.append(segment1) # Segment 2: Should be processed (enabled=True, status="completed") @@ -658,7 +666,7 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment2) + db_session_with_containers.add(segment2) segments.append(segment2) # Segment 3: Should NOT be processed (enabled=False, status="processing") @@ -677,7 +685,7 @@ class TestAddDocumentToIndexTask: status="processing", # Not completed created_by=document.created_by, ) - db.session.add(segment3) + db_session_with_containers.add(segment3) segments.append(segment3) # Segment 4: Should be processed (enabled=False, status="completed") @@ -696,10 +704,10 @@ class TestAddDocumentToIndexTask: status="completed", created_by=document.created_by, ) - db.session.add(segment4) + db_session_with_containers.add(segment4) segments.append(segment4) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -728,11 +736,11 @@ class TestAddDocumentToIndexTask: assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3 # Verify database state changes - db.session.refresh(document) - db.session.refresh(segment1) - db.session.refresh(segment2) - db.session.refresh(segment3) - db.session.refresh(segment4) + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + db_session_with_containers.refresh(segment3) + db_session_with_containers.refresh(segment4) # All segments should be enabled because the task updates ALL segments for the document assert segment1.enabled is True @@ -744,7 +752,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_comprehensive_error_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error scenarios and recovery. @@ -779,7 +787,7 @@ class TestAddDocumentToIndexTask: document.indexing_status = "completed" document.error = None document.disabled_at = None - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -789,7 +797,7 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify consistent error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False, f"Document should be disabled for {error_name}" assert document.indexing_status == "error", f"Document status should be error for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}" @@ -798,7 +806,7 @@ class TestAddDocumentToIndexTask: # Verify segments remain disabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False, f"Segments should remain disabled for {error_name}" # Verify redis cache was still cleared despite error diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index f94c5b19e6..ec789418a8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,8 +11,8 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -49,7 +49,7 @@ class TestBatchCleanDocumentTask: "get_image_ids": mock_get_image_ids, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -69,16 +69,16 @@ class TestBatchCleanDocumentTask: status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -87,15 +87,15 @@ class TestBatchCleanDocumentTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_dataset(self, db_session_with_containers, account): + def _create_test_dataset(self, db_session_with_containers: Session, account): """ Helper method to create a test dataset for testing. @@ -119,12 +119,12 @@ class TestBatchCleanDocumentTask: embedding_model_provider="openai", ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account): + def _create_test_document(self, db_session_with_containers: Session, dataset, account): """ Helper method to create a test document for testing. @@ -153,12 +153,12 @@ class TestBatchCleanDocumentTask: doc_form="text_model", ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_document_segment(self, db_session_with_containers, document, account): + def _create_test_document_segment(self, db_session_with_containers: Session, document, account): """ Helper method to create a test document segment for testing. @@ -186,12 +186,12 @@ class TestBatchCleanDocumentTask: status="completed", ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def _create_test_upload_file(self, db_session_with_containers, account): + def _create_test_upload_file(self, db_session_with_containers: Session, account): """ Helper method to create a test upload file for testing. @@ -220,13 +220,13 @@ class TestBatchCleanDocumentTask: used=False, ) - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file def test_batch_clean_document_task_successful_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful cleanup of documents with segments and files. @@ -245,7 +245,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -261,18 +261,18 @@ class TestBatchCleanDocumentTask: # The task should have processed the segment and cleaned up the database # Verify database cleanup - db.session.commit() # Ensure all changes are committed + db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_with_image_files( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of documents containing image references. @@ -300,8 +300,8 @@ class TestBatchCleanDocumentTask: status="completed", ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Store original IDs for verification segment_id = segment.id @@ -313,17 +313,17 @@ class TestBatchCleanDocumentTask: ) # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Verify that the task completed successfully by checking the log output # The task should have processed the segment and cleaned up the database def test_batch_clean_document_task_no_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when document has no segments. @@ -339,7 +339,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -354,21 +354,21 @@ class TestBatchCleanDocumentTask: # Since there are no segments, the task should handle this gracefully # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when dataset is not found. @@ -386,8 +386,8 @@ class TestBatchCleanDocumentTask: dataset_id = dataset.id # Delete the dataset to simulate not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Execute the task with non-existent dataset batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) @@ -399,14 +399,14 @@ class TestBatchCleanDocumentTask: mock_external_service_dependencies["storage"].delete.assert_not_called() # Verify that no database cleanup occurred - db.session.commit() + db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db.session.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when storage operations fail. @@ -423,7 +423,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -442,18 +442,18 @@ class TestBatchCleanDocumentTask: # The task should continue processing even when storage operations fail # Verify database cleanup still occurred despite storage failure - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_multiple_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of multiple documents in a single batch operation. @@ -482,7 +482,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -498,20 +498,20 @@ class TestBatchCleanDocumentTask: # The task should process all documents and clean up all associated resources # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup with different document form types. @@ -527,12 +527,12 @@ class TestBatchCleanDocumentTask: for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) - db.session.commit() + db_session_with_containers.commit() document = self._create_test_document(db_session_with_containers, dataset, account) # Update document doc_form document.doc_form = doc_form - db.session.commit() + db_session_with_containers.commit() segment = self._create_test_document_segment(db_session_with_containers, document, account) @@ -549,20 +549,20 @@ class TestBatchCleanDocumentTask: # The task should handle different document forms correctly # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None except Exception as e: # If the task fails due to external service issues (e.g., plugin daemon), # we should still verify that the database state is consistent # This is a common scenario in test environments where external services may not be available - db.session.commit() + db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -572,7 +572,7 @@ class TestBatchCleanDocumentTask: pass def test_batch_clean_document_task_large_batch_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup performance with a large batch of documents. @@ -604,7 +604,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -629,20 +629,20 @@ class TestBatchCleanDocumentTask: # The task should handle large batches efficiently # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test full integration with real database operations. @@ -683,12 +683,12 @@ class TestBatchCleanDocumentTask: # Add all to database for segment in segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Verify initial state - assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 + assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None # Store original IDs for verification document_id = document.id @@ -704,17 +704,17 @@ class TestBatchCleanDocumentTask: # The task should process all segments and clean up all associated resources # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify final database state - assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db.session.query(UploadFile).filter_by(id=file_id).first() is None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 + assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 2156743c17..a2324979db 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -29,20 +30,19 @@ class TestBatchCreateSegmentToIndexTask: """Integration tests for batch_create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" - from extensions.ext_database import db from extensions.ext_redis import redis_client # Clear all test data - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(UploadFile).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(UploadFile).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -75,7 +75,7 @@ class TestBatchCreateSegmentToIndexTask: "embedding_model": mock_embedding_model, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -95,18 +95,16 @@ class TestBatchCreateSegmentToIndexTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -115,15 +113,15 @@ class TestBatchCreateSegmentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -148,14 +146,12 @@ class TestBatchCreateSegmentToIndexTask: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -186,14 +182,12 @@ class TestBatchCreateSegmentToIndexTask: word_count=0, ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -221,10 +215,8 @@ class TestBatchCreateSegmentToIndexTask: used=False, ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file @@ -252,7 +244,7 @@ class TestBatchCreateSegmentToIndexTask: return csv_content def test_batch_create_segment_to_index_task_success_text_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful batch creation of segments for text model documents. @@ -293,11 +285,10 @@ class TestBatchCreateSegmentToIndexTask: ) # Verify results - from extensions.ext_database import db # Check that segments were created segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -316,7 +307,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.answer is None # text_model doesn't have answers # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called @@ -331,7 +322,7 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"completed" def test_batch_create_segment_to_index_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when dataset does not exist. @@ -370,17 +361,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created (since dataset doesn't exist) - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify no documents were modified - documents = db.session.query(Document).all() + documents = db_session_with_containers.query(Document).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document does not exist. @@ -419,18 +409,17 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) - db.session.refresh(dataset) - segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + db_session_with_containers.refresh(dataset) + segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document is not available for indexing. @@ -498,11 +487,9 @@ class TestBatchCreateSegmentToIndexTask: ), ] - from extensions.ext_database import db - for document in test_cases: - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Test each unavailable document for document in test_cases: @@ -524,11 +511,11 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when upload file does not exist. @@ -567,17 +554,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_empty_csv_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when CSV file is empty. @@ -619,17 +605,16 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_position_calculation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test proper position calculation for segments when existing segments exist. @@ -664,11 +649,9 @@ class TestBatchCreateSegmentToIndexTask: ) existing_segments.append(segment) - from extensions.ext_database import db - for segment in existing_segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Create CSV content csv_content = self._create_test_csv_content("text_model") @@ -695,7 +678,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions all_segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -716,7 +699,7 @@ class TestBatchCreateSegmentToIndexTask: assert segment.completed_at is not None # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index cd99b2965f..8eb881258a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( @@ -37,7 +38,7 @@ class TestCleanDatasetTask: """Integration tests for clean_dataset_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -82,7 +83,7 @@ class TestCleanDatasetTask: "index_processor": mock_index_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -127,7 +128,7 @@ class TestCleanDatasetTask: return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -157,7 +158,7 @@ class TestCleanDatasetTask: return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -194,7 +195,7 @@ class TestCleanDatasetTask: return document - def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document): + def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document): """ Helper method to create a test document segment for testing. @@ -230,7 +231,7 @@ class TestCleanDatasetTask: return segment - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -264,7 +265,7 @@ class TestCleanDatasetTask: return upload_file def test_clean_dataset_task_success_basic_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful basic dataset cleanup with minimal data. @@ -325,7 +326,7 @@ class TestCleanDatasetTask: mock_storage.delete.assert_not_called() def test_clean_dataset_task_success_with_documents_and_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with documents and segments. @@ -433,7 +434,7 @@ class TestCleanDatasetTask: assert mock_storage.delete.call_count == 3 def test_clean_dataset_task_success_with_invalid_doc_form( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with invalid doc_form handling. @@ -493,7 +494,7 @@ class TestCleanDatasetTask: assert mock_factory.call_count == 4 def test_clean_dataset_task_error_handling_and_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling and rollback mechanism when database operations fail. @@ -542,7 +543,7 @@ class TestCleanDatasetTask: # This demonstrates the resilience of the cleanup process def test_clean_dataset_task_with_image_file_references( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with image file references in document segments. @@ -634,7 +635,7 @@ class TestCleanDatasetTask: mock_get_image_ids.assert_called_once() def test_clean_dataset_task_performance_with_large_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup performance with large amounts of data. @@ -704,11 +705,9 @@ class TestCleanDatasetTask: binding.created_at = datetime.now() bindings.append(binding) - from extensions.ext_database import db - - db.session.add_all(metadata_items) - db.session.add_all(bindings) - db.session.commit() + db_session_with_containers.add_all(metadata_items) + db_session_with_containers.add_all(bindings) + db_session_with_containers.commit() # Measure cleanup performance import time @@ -772,7 +771,7 @@ class TestCleanDatasetTask: print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds") def test_clean_dataset_task_storage_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup when storage operations fail. @@ -838,7 +837,7 @@ class TestCleanDatasetTask: # consistency in the database def test_clean_dataset_task_edge_cases_and_boundary_conditions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with edge cases and boundary conditions. diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index 8785c948d1..ab9e5b639a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -13,8 +13,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -34,7 +34,7 @@ class TestDisableSegmentFromIndexTask: mock_processor.clean.return_value = None yield mock_processor - def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: + def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -53,8 +53,8 @@ class TestDisableSegmentFromIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +62,8 @@ class TestDisableSegmentFromIndexTask: status="normal", plan="basic", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +72,15 @@ class TestDisableSegmentFromIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: + def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset: """ Helper method to create a test dataset. @@ -101,13 +101,18 @@ class TestDisableSegmentFromIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset def _create_test_document( - self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" + self, + db_session_with_containers: Session, + dataset: Dataset, + tenant: Tenant, + account: Account, + doc_form: str = "text_model", ) -> Document: """ Helper method to create a test document. @@ -140,13 +145,14 @@ class TestDisableSegmentFromIndexTask: tokens=500, completed_at=datetime.now(UTC), ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document def _create_test_segment( self, + db_session_with_containers: Session, document: Document, dataset: Dataset, tenant: Tenant, @@ -185,12 +191,12 @@ class TestDisableSegmentFromIndexTask: created_by=account.id, completed_at=datetime.now(UTC) if status == "completed" else None, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor): """ Test successful segment disabling from index. @@ -202,9 +208,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -226,10 +232,10 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Verify segment is still in database - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not found. @@ -251,7 +257,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not in completed status. @@ -262,9 +268,11 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with non-completed segment account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment( + db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True + ) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -275,7 +283,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated dataset. @@ -286,13 +294,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove dataset association segment.dataset_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -303,7 +311,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated document. @@ -314,13 +322,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove document association segment.document_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -331,7 +339,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is disabled. @@ -342,12 +350,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with disabled document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.enabled = False - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -358,7 +366,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is archived. @@ -369,12 +377,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with archived document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.archived = True - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -385,7 +393,9 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_indexing_not_completed( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test handling when document indexing is not completed. @@ -396,12 +406,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with incomplete indexing account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -412,7 +422,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when index processor raises an exception. @@ -424,9 +434,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -449,13 +459,13 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][1] == [segment.index_node_id] # Check index node IDs # Verify segment was re-enabled - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify Redis cache was still cleared assert redis_client.get(indexing_cache_key) is None - def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor): """ Test disabling segments with different document forms. @@ -470,9 +480,11 @@ class TestDisableSegmentFromIndexTask: for doc_form in doc_forms: # Arrange: Create test data for each form account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document( + db_session_with_containers, dataset, tenant, account, doc_form=doc_form + ) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Reset mock for each iteration mock_index_processor.reset_mock() @@ -489,7 +501,7 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][0].id == dataset.id # Check dataset ID assert call_args[0][1] == [segment.index_node_id] # Check index node IDs - def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor): """ Test Redis cache handling during segment disabling. @@ -500,9 +512,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Test with cache present indexing_cache_key = f"segment_{segment.id}_indexing" @@ -517,13 +529,13 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Test with no cache present - segment2 = self._create_test_segment(document, dataset, tenant, account) + segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) result2 = disable_segment_from_index_task(segment2.id) # Assert: Verify task still works without cache assert result2 is None - def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor): """ Test performance timing of segment disabling task. @@ -534,9 +546,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task and measure time start_time = time.perf_counter() @@ -548,7 +560,9 @@ class TestDisableSegmentFromIndexTask: execution_time = end_time - start_time assert execution_time < 5.0 # Should complete within 5 seconds - def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_database_session_management( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test database session management during task execution. @@ -559,9 +573,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -570,10 +584,10 @@ class TestDisableSegmentFromIndexTask: assert result is None # Verify segment is still accessible (session was properly managed) - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor): """ Test concurrent execution of segment disabling tasks. @@ -584,12 +598,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create multiple test segments account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) segments = [] for i in range(3): - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) segments.append(segment) # Act: Execute tasks concurrently (simulated) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index a93a80e231..8f47b48ae2 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument @@ -31,7 +32,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -79,7 +80,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): """ Helper method to create a test dataset with realistic data. @@ -113,7 +114,7 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): """ Helper method to create a test document with realistic data. @@ -158,7 +159,9 @@ class TestDisableSegmentsFromIndexTask: return document - def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + def _create_test_segments( + self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None + ): """ Helper method to create test document segments with realistic data. @@ -210,7 +213,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): """ Helper method to create a dataset process rule. @@ -239,14 +242,12 @@ class TestDisableSegmentsFromIndexTask: process_rule.created_by = dataset.created_by process_rule.updated_by = dataset.updated_by - from extensions.ext_database import db - - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule - def test_disable_segments_success(self, db_session_with_containers): + def test_disable_segments_success(self, db_session_with_containers: Session): """ Test successful disabling of segments from index. @@ -297,7 +298,7 @@ class TestDisableSegmentsFromIndexTask: expected_key = f"segment_{segment.id}_indexing" mock_redis.delete.assert_any_call(expected_key) - def test_disable_segments_dataset_not_found(self, db_session_with_containers): + def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -320,7 +321,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when dataset is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_not_found(self, db_session_with_containers): + def test_disable_segments_document_not_found(self, db_session_with_containers: Session): """ Test handling when document is not found. @@ -344,7 +345,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when document is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_invalid_status(self, db_session_with_containers): + def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session): """ Test handling when document has invalid status for disabling. @@ -360,9 +361,8 @@ class TestDisableSegmentsFromIndexTask: # Test case 1: Document not enabled document.enabled = False - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() segment_ids = [segment.id for segment in segments] @@ -379,7 +379,7 @@ class TestDisableSegmentsFromIndexTask: # Test case 2: Document archived document.enabled = True document.archived = True - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -393,7 +393,7 @@ class TestDisableSegmentsFromIndexTask: document.enabled = True document.archived = False document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -403,7 +403,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_redis.delete.assert_not_called() - def test_disable_segments_no_segments_found(self, db_session_with_containers): + def test_disable_segments_no_segments_found(self, db_session_with_containers: Session): """ Test handling when no segments are found for the given IDs. @@ -430,7 +430,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are found mock_redis.delete.assert_not_called() - def test_disable_segments_index_processor_error(self, db_session_with_containers): + def test_disable_segments_index_processor_error(self, db_session_with_containers: Session): """ Test handling when index processor encounters an error. @@ -464,13 +464,14 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Verify segments were rolled back to enabled state - from extensions.ext_database import db - db.session.refresh(segments[0]) - db.session.refresh(segments[1]) + db_session_with_containers.refresh(segments[0]) + db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + updated_segments = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + ) for segment in updated_segments: assert segment.enabled is True @@ -480,7 +481,7 @@ class TestDisableSegmentsFromIndexTask: # Verify Redis cache cleanup was still called assert mock_redis.delete.call_count == len(segments) - def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session): """ Test disabling segments with different document forms. @@ -503,9 +504,8 @@ class TestDisableSegmentsFromIndexTask: for doc_form in doc_forms: # Update document form document.doc_form = doc_form - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock the index processor factory with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: @@ -523,7 +523,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_factory.assert_called_with(doc_form) - def test_disable_segments_performance_timing(self, db_session_with_containers): + def test_disable_segments_performance_timing(self, db_session_with_containers: Session): """ Test that the task properly measures and logs performance timing. @@ -568,7 +568,7 @@ class TestDisableSegmentsFromIndexTask: assert performance_log is not None assert "0.5" in performance_log # Should log the execution time - def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session): """ Test that Redis cache is properly cleaned up for all segments. @@ -610,7 +610,7 @@ class TestDisableSegmentsFromIndexTask: for expected_key in expected_keys: assert expected_key in actual_calls - def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session): """ Test that database session is properly closed after task execution. @@ -643,7 +643,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Session lifecycle is managed by context manager; no explicit close assertion - def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session): """ Test handling when empty segment IDs list is provided. @@ -669,7 +669,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are provided mock_redis.delete.assert_not_called() - def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session): """ Test handling when some segment IDs are valid and others are invalid. diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index b3d9e49b30..bc29395545 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment @@ -31,7 +31,9 @@ class TestEnableSegmentsToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +53,15 @@ class TestEnableSegmentsToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +70,8 @@ class TestEnableSegmentsToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -81,8 +83,8 @@ class TestEnableSegmentsToIndexTask: indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -99,16 +101,16 @@ class TestEnableSegmentsToIndexTask: enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document def _create_test_segments( - self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed" + self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed" ): """ Helper method to create test document segments. @@ -144,14 +146,14 @@ class TestEnableSegmentsToIndexTask: status=status, created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments def test_enable_segments_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with different index types. @@ -169,10 +171,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -204,7 +206,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -229,7 +231,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -256,7 +258,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_invalid_document_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid status. @@ -284,12 +286,12 @@ class TestEnableSegmentsToIndexTask: document.enabled = True document.archived = False document.indexing_status = "completed" - db.session.commit() + db_session_with_containers.commit() # Set invalid status for attr, value in status_attrs.items(): setattr(document, attr, value) - db.session.commit() + db_session_with_containers.commit() # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -304,11 +306,11 @@ class TestEnableSegmentsToIndexTask: # Clean up segments for next iteration for segment in segments: - db.session.delete(segment) - db.session.commit() + db_session_with_containers.delete(segment) + db_session_with_containers.commit() def test_enable_segments_to_index_segments_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when no segments are found. @@ -338,7 +340,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with parent-child structure. @@ -357,10 +359,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -410,7 +412,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -443,7 +445,7 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify error handling for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False assert segment.status == "error" assert segment.error is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py index 6c3a9ef20a..ff72232d12 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -2,8 +2,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task @@ -30,7 +30,7 @@ class TestMailAccountDeletionTask: "email_service": mock_email_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -49,16 +49,16 @@ class TestMailAccountDeletionTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -67,12 +67,14 @@ class TestMailAccountDeletionTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return account - def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_deletion_success_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful account deletion success email sending. @@ -109,7 +111,7 @@ class TestMailAccountDeletionTask: ) def test_send_deletion_success_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when mail service is not initialized. @@ -132,7 +134,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_deletion_success_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when email service raises exception. @@ -154,7 +156,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_called_once() def test_send_account_deletion_verification_code_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful account deletion verification code email sending. @@ -193,7 +195,7 @@ class TestMailAccountDeletionTask: ) def test_send_account_deletion_verification_code_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when mail service is not initialized. @@ -217,7 +219,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_account_deletion_verification_code_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when email service raises exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index b9977b1fb6..ef7191299a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,11 +4,11 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Pipeline from models.workflow import Workflow @@ -52,7 +52,7 @@ class TestRagPipelineRunTasks: "delete_file": mock_delete_file, } - def _create_test_pipeline_and_workflow(self, db_session_with_containers): + def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session): """ Helper method to create test pipeline and workflow for testing. @@ -71,15 +71,15 @@ class TestRagPipelineRunTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,8 +88,8 @@ class TestRagPipelineRunTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create workflow workflow = Workflow( @@ -107,8 +107,8 @@ class TestRagPipelineRunTasks: conversation_variables=[], rag_pipeline_variables=[], ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create pipeline pipeline = Pipeline( @@ -119,14 +119,14 @@ class TestRagPipelineRunTasks: created_by=account.id, ) pipeline.id = str(uuid.uuid4()) - db.session.add(pipeline) - db.session.commit() + db_session_with_containers.add(pipeline) + db_session_with_containers.commit() # Refresh entities to ensure they're properly loaded - db.session.refresh(account) - db.session.refresh(tenant) - db.session.refresh(workflow) - db.session.refresh(pipeline) + db_session_with_containers.refresh(account) + db_session_with_containers.refresh(tenant) + db_session_with_containers.refresh(workflow) + db_session_with_containers.refresh(pipeline) return account, tenant, pipeline, workflow @@ -209,7 +209,7 @@ class TestRagPipelineRunTasks: return json.dumps(entities_data) def test_priority_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful priority RAG pipeline run task execution. @@ -254,7 +254,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful regular RAG pipeline run task execution. @@ -299,7 +299,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_priority_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with waiting tasks in queue using real Redis. @@ -351,7 +351,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining def test_rag_pipeline_run_task_legacy_compatibility( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. @@ -419,7 +419,7 @@ class TestRagPipelineRunTasks: redis_client.delete(legacy_task_key) def test_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with waiting tasks in queue using real Redis. @@ -469,7 +469,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining def test_priority_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in priority RAG pipeline run task using real Redis. @@ -526,7 +526,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in regular RAG pipeline run task using real Redis. @@ -581,7 +581,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_priority_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in priority RAG pipeline run task using real Redis. @@ -648,7 +648,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in regular RAG pipeline run task using real Redis. @@ -713,7 +713,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test successful run_single_rag_pipeline_task execution. @@ -748,7 +748,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -793,7 +793,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with non-existent database entities. @@ -838,7 +838,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_priority_rag_pipeline_run_task_file_not_found( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with non-existent file. @@ -888,7 +888,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_file_not_found( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with non-existent file.