mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 11:10:19 +08:00
feat: replace db.session with db_session_with_containers (#32942)
This commit is contained in:
parent
2f4c740d46
commit
ad000c42b7
@ -9,8 +9,8 @@ from itertools import starmap
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
|
||||
@ -28,6 +28,7 @@ class DatasetCollectionBindingTestDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_collection_binding(
|
||||
db_session_with_containers: Session,
|
||||
provider_name: str = "openai",
|
||||
model_name: str = "text-embedding-ada-002",
|
||||
collection_name: str = "collection-abc",
|
||||
@ -51,8 +52,8 @@ class DatasetCollectionBindingTestDataFactory:
|
||||
collection_name=collection_name,
|
||||
type=collection_type,
|
||||
)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
return binding
|
||||
|
||||
|
||||
@ -64,7 +65,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
||||
including various provider/model combinations, collection types, and edge cases.
|
||||
"""
|
||||
|
||||
def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful retrieval of an existing collection binding.
|
||||
|
||||
@ -77,6 +78,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
||||
model_name = "text-embedding-ada-002"
|
||||
collection_type = "dataset"
|
||||
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
collection_name="existing-collection",
|
||||
@ -92,7 +94,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
||||
assert result.id == existing_binding.id
|
||||
assert result.collection_name == "existing-collection"
|
||||
|
||||
def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful creation of a new collection binding when none exists.
|
||||
|
||||
@ -116,7 +118,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
||||
assert result.type == collection_type
|
||||
assert result.collection_name is not None
|
||||
|
||||
def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session):
|
||||
"""Test get_dataset_collection_binding with different collection type."""
|
||||
# Arrange
|
||||
provider_name = "openai"
|
||||
@ -133,7 +135,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
||||
assert result.provider_name == provider_name
|
||||
assert result.model_name == model_name
|
||||
|
||||
def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session):
|
||||
"""Test get_dataset_collection_binding with default collection type parameter."""
|
||||
# Arrange
|
||||
provider_name = "openai"
|
||||
@ -147,7 +149,9 @@ class TestDatasetCollectionBindingServiceGetBinding:
|
||||
assert result.provider_name == provider_name
|
||||
assert result.model_name == model_name
|
||||
|
||||
def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_different_provider_model_combination(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""Test get_dataset_collection_binding with various provider/model combinations."""
|
||||
# Arrange
|
||||
combinations = [
|
||||
@ -174,10 +178,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
||||
including successful retrieval and error handling for missing bindings.
|
||||
"""
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session):
|
||||
"""Test successful retrieval of collection binding by ID and type."""
|
||||
# Arrange
|
||||
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
@ -194,7 +199,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
||||
assert result.collection_name == "test-collection"
|
||||
assert result.type == "dataset"
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
|
||||
"""Test error handling when collection binding is not found by ID and type."""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid4())
|
||||
@ -203,10 +208,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
||||
with pytest.raises(ValueError, match="Dataset collection binding not found"):
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""Test retrieval by ID and type with different collection type."""
|
||||
# Arrange
|
||||
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
@ -222,10 +230,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
||||
assert result.id == binding.id
|
||||
assert result.type == "custom_type"
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""Test retrieval by ID with default collection type."""
|
||||
# Arrange
|
||||
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
@ -239,10 +250,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
|
||||
assert result.id == binding.id
|
||||
assert result.type == "dataset"
|
||||
|
||||
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers):
|
||||
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
|
||||
"""Test error when binding exists but with wrong collection type."""
|
||||
# Arrange
|
||||
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
|
||||
db_session_with_containers,
|
||||
provider_name="openai",
|
||||
model_name="text-embedding-ada-002",
|
||||
collection_name="test-collection",
|
||||
|
||||
@ -10,9 +10,9 @@ from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
|
||||
from models.model import App
|
||||
@ -27,6 +27,7 @@ class DatasetUpdateDeleteTestDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session,
|
||||
role: TenantAccountRole = TenantAccountRole.NORMAL,
|
||||
tenant: Tenant | None = None,
|
||||
) -> tuple[Account, Tenant]:
|
||||
@ -37,13 +38,13 @@ class DatasetUpdateDeleteTestDataFactory:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
if tenant is None:
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
@ -51,14 +52,15 @@ class DatasetUpdateDeleteTestDataFactory:
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
name: str = "Test Dataset",
|
||||
@ -78,12 +80,12 @@ class DatasetUpdateDeleteTestDataFactory:
|
||||
retrieval_model={"top_k": 2},
|
||||
enable_api=enable_api,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App:
|
||||
def create_app(db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test App") -> App:
|
||||
"""Create a real app for AppDatasetJoin."""
|
||||
app = App(
|
||||
tenant_id=tenant_id,
|
||||
@ -96,16 +98,16 @@ class DatasetUpdateDeleteTestDataFactory:
|
||||
enable_api=True,
|
||||
created_by=created_by,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin:
|
||||
def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin:
|
||||
"""Create a real AppDatasetJoin record."""
|
||||
join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
return join
|
||||
|
||||
|
||||
@ -114,7 +116,7 @@ class TestDatasetServiceDeleteDataset:
|
||||
Comprehensive integration tests for DatasetService.delete_dataset method.
|
||||
"""
|
||||
|
||||
def test_delete_dataset_success(self, db_session_with_containers):
|
||||
def test_delete_dataset_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful deletion of a dataset.
|
||||
|
||||
@ -130,8 +132,10 @@ class TestDatasetServiceDeleteDataset:
|
||||
- Method returns True
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted:
|
||||
@ -139,10 +143,10 @@ class TestDatasetServiceDeleteDataset:
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
mock_dataset_was_deleted.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_dataset_not_found(self, db_session_with_containers):
|
||||
def test_delete_dataset_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when dataset is not found.
|
||||
|
||||
@ -156,7 +160,9 @@ class TestDatasetServiceDeleteDataset:
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange
|
||||
owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
@ -165,7 +171,7 @@ class TestDatasetServiceDeleteDataset:
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_delete_dataset_permission_denied_error(self, db_session_with_containers):
|
||||
def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error handling when user lacks permission.
|
||||
|
||||
@ -178,19 +184,22 @@ class TestDatasetServiceDeleteDataset:
|
||||
- No database operations are performed
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers,
|
||||
role=TenantAccountRole.NORMAL,
|
||||
tenant=tenant,
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.delete_dataset(dataset.id, normal_user)
|
||||
|
||||
# Verify no deletion was attempted
|
||||
assert db.session.get(Dataset, dataset.id) is not None
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is not None
|
||||
|
||||
|
||||
class TestDatasetServiceDatasetUseCheck:
|
||||
@ -198,7 +207,7 @@ class TestDatasetServiceDatasetUseCheck:
|
||||
Comprehensive integration tests for DatasetService.dataset_use_check method.
|
||||
"""
|
||||
|
||||
def test_dataset_use_check_in_use(self, db_session_with_containers):
|
||||
def test_dataset_use_check_in_use(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test detection when dataset is in use.
|
||||
|
||||
@ -211,10 +220,12 @@ class TestDatasetServiceDatasetUseCheck:
|
||||
- Database query is executed
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id)
|
||||
DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id)
|
||||
DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id)
|
||||
|
||||
# Act
|
||||
result = DatasetService.dataset_use_check(dataset.id)
|
||||
@ -222,7 +233,7 @@ class TestDatasetServiceDatasetUseCheck:
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_dataset_use_check_not_in_use(self, db_session_with_containers):
|
||||
def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test detection when dataset is not in use.
|
||||
|
||||
@ -235,8 +246,10 @@ class TestDatasetServiceDatasetUseCheck:
|
||||
- Database query is executed
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
|
||||
# Act
|
||||
result = DatasetService.dataset_use_check(dataset.id)
|
||||
@ -250,7 +263,7 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
Comprehensive integration tests for DatasetService.update_dataset_api_status method.
|
||||
"""
|
||||
|
||||
def test_update_dataset_api_status_enable_success(self, db_session_with_containers):
|
||||
def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful enabling of dataset API access.
|
||||
|
||||
@ -264,8 +277,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
- Transaction is committed
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant.id, owner.id, enable_api=False
|
||||
)
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
|
||||
# Act
|
||||
@ -276,12 +293,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
DatasetService.update_dataset_api_status(dataset.id, True)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.enable_api is True
|
||||
assert dataset.updated_by == owner.id
|
||||
assert dataset.updated_at == current_time
|
||||
|
||||
def test_update_dataset_api_status_disable_success(self, db_session_with_containers):
|
||||
def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful disabling of dataset API access.
|
||||
|
||||
@ -295,8 +312,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
- Transaction is committed
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant.id, owner.id, enable_api=True
|
||||
)
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
|
||||
# Act
|
||||
@ -307,11 +328,11 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
DatasetService.update_dataset_api_status(dataset.id, False)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.enable_api is False
|
||||
assert dataset.updated_by == owner.id
|
||||
|
||||
def test_update_dataset_api_status_not_found_error(self, db_session_with_containers):
|
||||
def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error handling when dataset is not found.
|
||||
|
||||
@ -330,7 +351,7 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
DatasetService.update_dataset_api_status(dataset_id, True)
|
||||
|
||||
def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers):
|
||||
def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error handling when current_user is missing.
|
||||
|
||||
@ -343,8 +364,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
- No updates are committed
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False)
|
||||
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant.id, owner.id, enable_api=False
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with (
|
||||
@ -354,6 +379,6 @@ class TestDatasetServiceUpdateDatasetApiStatus:
|
||||
DatasetService.update_dataset_api_status(dataset.id, True)
|
||||
|
||||
# Verify no commit was attempted
|
||||
db.session.rollback()
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.rollback()
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.enable_api is False
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
@ -87,7 +88,7 @@ class TestAgentService:
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -133,13 +134,12 @@ class TestAgentService:
|
||||
# Update the app model config to set agent_mode for agent-chat mode
|
||||
if app.mode == "agent-chat" and app.app_model_config:
|
||||
app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []})
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_conversation_and_message(self, db_session_with_containers, app, account):
|
||||
def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account):
|
||||
"""
|
||||
Helper method to create a test conversation and message with agent thoughts.
|
||||
|
||||
@ -153,8 +153,6 @@ class TestAgentService:
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create conversation
|
||||
conversation = Conversation(
|
||||
id=fake.uuid4(),
|
||||
@ -167,8 +165,8 @@ class TestAgentService:
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
@ -180,12 +178,12 @@ class TestAgentService:
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app_model_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update conversation with app model config
|
||||
conversation.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
@ -206,12 +204,12 @@ class TestAgentService:
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return conversation, message
|
||||
|
||||
def _create_test_agent_thoughts(self, db_session_with_containers, message):
|
||||
def _create_test_agent_thoughts(self, db_session_with_containers: Session, message):
|
||||
"""
|
||||
Helper method to create test agent thoughts for a message.
|
||||
|
||||
@ -224,8 +222,6 @@ class TestAgentService:
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
agent_thoughts = []
|
||||
|
||||
# Create first agent thought
|
||||
@ -251,7 +247,7 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(thought1)
|
||||
db_session_with_containers.add(thought1)
|
||||
agent_thoughts.append(thought1)
|
||||
|
||||
# Create second agent thought
|
||||
@ -277,14 +273,14 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(thought2)
|
||||
db_session_with_containers.add(thought2)
|
||||
agent_thoughts.append(thought2)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return agent_thoughts
|
||||
|
||||
def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of agent logs with complete data.
|
||||
"""
|
||||
@ -344,7 +340,7 @@ class TestAgentService:
|
||||
assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon
|
||||
|
||||
def test_get_agent_logs_conversation_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when conversation is not found.
|
||||
@ -358,7 +354,9 @@ class TestAgentService:
|
||||
with pytest.raises(ValueError, match="Conversation not found"):
|
||||
AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4())
|
||||
|
||||
def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_message_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when message is not found.
|
||||
"""
|
||||
@ -372,7 +370,9 @@ class TestAgentService:
|
||||
with pytest.raises(ValueError, match="Message not found"):
|
||||
AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4())
|
||||
|
||||
def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_with_end_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval when conversation is from end user.
|
||||
"""
|
||||
@ -381,8 +381,6 @@ class TestAgentService:
|
||||
# Create test data
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create end user
|
||||
end_user = EndUser(
|
||||
id=fake.uuid4(),
|
||||
@ -393,8 +391,8 @@ class TestAgentService:
|
||||
session_id=fake.uuid4(),
|
||||
name=fake.name(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create conversation with end user
|
||||
conversation = Conversation(
|
||||
@ -408,8 +406,8 @@ class TestAgentService:
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
@ -421,12 +419,12 @@ class TestAgentService:
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app_model_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update conversation with app model config
|
||||
conversation.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
@ -447,8 +445,8 @@ class TestAgentService:
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
@ -457,7 +455,9 @@ class TestAgentService:
|
||||
assert result is not None
|
||||
assert result["meta"]["executor"] == end_user.name
|
||||
|
||||
def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_with_unknown_executor(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval when executor is unknown.
|
||||
"""
|
||||
@ -466,8 +466,6 @@ class TestAgentService:
|
||||
# Create test data
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create conversation with non-existent account
|
||||
conversation = Conversation(
|
||||
id=fake.uuid4(),
|
||||
@ -480,8 +478,8 @@ class TestAgentService:
|
||||
mode="chat",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
@ -493,12 +491,12 @@ class TestAgentService:
|
||||
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app_model_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update conversation with app model config
|
||||
conversation.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
@ -519,8 +517,8 @@ class TestAgentService:
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
@ -529,7 +527,9 @@ class TestAgentService:
|
||||
assert result is not None
|
||||
assert result["meta"]["executor"] == "Unknown"
|
||||
|
||||
def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_with_tool_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval with tool errors.
|
||||
"""
|
||||
@ -539,8 +539,6 @@ class TestAgentService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create agent thought with tool error
|
||||
thought_with_error = MessageAgentThought(
|
||||
message_id=message.id,
|
||||
@ -564,8 +562,8 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(thought_with_error)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(thought_with_error)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
@ -580,7 +578,7 @@ class TestAgentService:
|
||||
assert tool_call["error"] == "Tool execution failed"
|
||||
|
||||
def test_get_agent_logs_without_agent_thoughts(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval when message has no agent thoughts.
|
||||
@ -600,7 +598,7 @@ class TestAgentService:
|
||||
assert len(result["iterations"]) == 0
|
||||
|
||||
def test_get_agent_logs_app_model_config_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when app model config is not found.
|
||||
@ -610,11 +608,9 @@ class TestAgentService:
|
||||
# Create test data
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Remove app model config to test error handling
|
||||
app.app_model_config_id = None
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create conversation without app model config
|
||||
conversation = Conversation(
|
||||
@ -629,8 +625,8 @@ class TestAgentService:
|
||||
from_source="api",
|
||||
app_model_config_id=None, # Explicitly set to None
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
@ -651,15 +647,15 @@ class TestAgentService:
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
)
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
with pytest.raises(ValueError, match="App model config not found"):
|
||||
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
|
||||
def test_get_agent_logs_agent_config_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when agent config is not found.
|
||||
@ -677,7 +673,9 @@ class TestAgentService:
|
||||
with pytest.raises(ValueError, match="Agent config not found"):
|
||||
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
|
||||
def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_list_agent_providers_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful listing of agent providers.
|
||||
"""
|
||||
@ -698,7 +696,7 @@ class TestAgentService:
|
||||
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
|
||||
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id))
|
||||
|
||||
def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of specific agent provider.
|
||||
"""
|
||||
@ -720,7 +718,9 @@ class TestAgentService:
|
||||
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
|
||||
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name)
|
||||
|
||||
def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_provider_plugin_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when plugin daemon client raises an error.
|
||||
"""
|
||||
@ -741,7 +741,7 @@ class TestAgentService:
|
||||
AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
|
||||
|
||||
def test_get_agent_logs_with_complex_tool_data(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval with complex tool data and multiple tools.
|
||||
@ -752,8 +752,6 @@ class TestAgentService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create agent thought with multiple tools
|
||||
complex_thought = MessageAgentThought(
|
||||
message_id=message.id,
|
||||
@ -799,8 +797,8 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(complex_thought)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(complex_thought)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
@ -831,7 +829,7 @@ class TestAgentService:
|
||||
assert tool_calls[2]["status"] == "success"
|
||||
assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon
|
||||
|
||||
def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test agent logs retrieval with message files and agent thought files.
|
||||
"""
|
||||
@ -842,7 +840,6 @@ class TestAgentService:
|
||||
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
|
||||
|
||||
from dify_graph.file import FileTransferMethod, FileType
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
|
||||
# Add files to message
|
||||
@ -867,9 +864,9 @@ class TestAgentService:
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(message_file1)
|
||||
db.session.add(message_file2)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message_file1)
|
||||
db_session_with_containers.add(message_file2)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create agent thought with files
|
||||
thought_with_files = MessageAgentThought(
|
||||
@ -895,8 +892,8 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(thought_with_files)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(thought_with_files)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
@ -912,7 +909,7 @@ class TestAgentService:
|
||||
assert "file2" in iterations[0]["files"]
|
||||
|
||||
def test_get_agent_logs_with_different_timezone(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval with different timezone settings.
|
||||
@ -938,7 +935,9 @@ class TestAgentService:
|
||||
assert "T" in start_time # ISO format
|
||||
assert "+08:00" in start_time or "Z" in start_time # Timezone offset
|
||||
|
||||
def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_with_empty_tool_data(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval with empty tool data.
|
||||
"""
|
||||
@ -948,8 +947,6 @@ class TestAgentService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create agent thought with empty tool data
|
||||
empty_thought = MessageAgentThought(
|
||||
message_id=message.id,
|
||||
@ -964,8 +961,8 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(empty_thought)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(empty_thought)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
@ -979,7 +976,9 @@ class TestAgentService:
|
||||
tool_calls = iterations[0]["tool_calls"]
|
||||
assert len(tool_calls) == 0 # No tools to process
|
||||
|
||||
def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_agent_logs_with_malformed_json(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test agent logs retrieval with malformed JSON data in tool fields.
|
||||
"""
|
||||
@ -989,8 +988,6 @@ class TestAgentService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create agent thought with malformed JSON
|
||||
malformed_thought = MessageAgentThought(
|
||||
message_id=message.id,
|
||||
@ -1005,8 +1002,8 @@ class TestAgentService:
|
||||
created_by_role="account",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(malformed_thought)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(malformed_thought)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the method under test
|
||||
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models import Account
|
||||
@ -52,7 +53,7 @@ class TestAnnotationService:
|
||||
"current_user": mock_user,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -115,11 +116,10 @@ class TestAnnotationService:
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
def _create_test_conversation(self, app, account, fake):
|
||||
def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake):
|
||||
"""
|
||||
Helper method to create a test conversation with all required fields.
|
||||
"""
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
@ -141,17 +141,16 @@ class TestAnnotationService:
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.flush()
|
||||
return conversation
|
||||
|
||||
def _create_test_message(self, app, conversation, account, fake):
|
||||
def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake):
|
||||
"""
|
||||
Helper method to create a test message with all required fields.
|
||||
"""
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
message = Message(
|
||||
@ -180,12 +179,12 @@ class TestAnnotationService:
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
return message
|
||||
|
||||
def test_insert_app_annotation_directly_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful direct insertion of app annotation.
|
||||
@ -211,9 +210,8 @@ class TestAnnotationService:
|
||||
assert annotation.id is not None
|
||||
|
||||
# Verify annotation was saved to database
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(annotation)
|
||||
db_session_with_containers.refresh(annotation)
|
||||
assert annotation.id is not None
|
||||
|
||||
# Verify add_annotation_to_index_task was called (when annotation setting exists)
|
||||
@ -221,7 +219,7 @@ class TestAnnotationService:
|
||||
mock_external_service_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_insert_app_annotation_directly_requires_question(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Question must be provided when inserting annotations directly.
|
||||
@ -238,7 +236,7 @@ class TestAnnotationService:
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
|
||||
|
||||
def test_insert_app_annotation_directly_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test direct insertion of app annotation when app is not found.
|
||||
@ -260,7 +258,7 @@ class TestAnnotationService:
|
||||
AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id)
|
||||
|
||||
def test_update_app_annotation_directly_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful direct update of app annotation.
|
||||
@ -298,7 +296,7 @@ class TestAnnotationService:
|
||||
mock_external_service_dependencies["update_task"].delay.assert_not_called()
|
||||
|
||||
def test_up_insert_app_annotation_from_message_new(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creating new annotation from message.
|
||||
@ -307,8 +305,8 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message first
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Setup annotation data with message_id
|
||||
annotation_args = {
|
||||
@ -333,7 +331,7 @@ class TestAnnotationService:
|
||||
mock_external_service_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_up_insert_app_annotation_from_message_update(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test updating existing annotation from message.
|
||||
@ -342,8 +340,8 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message first
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create initial annotation
|
||||
initial_args = {
|
||||
@ -373,7 +371,7 @@ class TestAnnotationService:
|
||||
mock_external_service_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_up_insert_app_annotation_from_message_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creating annotation from message when app is not found.
|
||||
@ -395,7 +393,7 @@ class TestAnnotationService:
|
||||
AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id)
|
||||
|
||||
def test_get_annotation_list_by_app_id_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of annotation list by app ID.
|
||||
@ -428,7 +426,7 @@ class TestAnnotationService:
|
||||
assert annotation.account_id == account.id
|
||||
|
||||
def test_get_annotation_list_by_app_id_with_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of annotation list with keyword search.
|
||||
@ -462,7 +460,7 @@ class TestAnnotationService:
|
||||
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
|
||||
|
||||
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
|
||||
@ -534,7 +532,7 @@ class TestAnnotationService:
|
||||
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
|
||||
|
||||
def test_get_annotation_list_by_app_id_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of annotation list when app is not found.
|
||||
@ -549,7 +547,9 @@ class TestAnnotationService:
|
||||
with pytest.raises(NotFound, match="App not found"):
|
||||
AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="")
|
||||
|
||||
def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_app_annotation_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful deletion of app annotation.
|
||||
"""
|
||||
@ -568,16 +568,19 @@ class TestAnnotationService:
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
|
||||
|
||||
# Verify annotation was deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
|
||||
deleted_annotation = (
|
||||
db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
|
||||
)
|
||||
assert deleted_annotation is None
|
||||
|
||||
# Verify delete_annotation_index_task was called (when annotation setting exists)
|
||||
# Note: In this test, no annotation setting exists, so task should not be called
|
||||
mock_external_service_dependencies["delete_task"].delay.assert_not_called()
|
||||
|
||||
def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_app_annotation_app_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test deletion of app annotation when app is not found.
|
||||
"""
|
||||
@ -593,7 +596,7 @@ class TestAnnotationService:
|
||||
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
|
||||
|
||||
def test_delete_app_annotation_annotation_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test deletion of app annotation when annotation is not found.
|
||||
@ -606,7 +609,9 @@ class TestAnnotationService:
|
||||
with pytest.raises(NotFound, match="Annotation not found"):
|
||||
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
|
||||
|
||||
def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_enable_app_annotation_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful enabling of app annotation.
|
||||
"""
|
||||
@ -632,7 +637,9 @@ class TestAnnotationService:
|
||||
# Verify task was called
|
||||
mock_external_service_dependencies["enable_task"].delay.assert_called_once()
|
||||
|
||||
def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_disable_app_annotation_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful disabling of app annotation.
|
||||
"""
|
||||
@ -651,7 +658,9 @@ class TestAnnotationService:
|
||||
# Verify task was called
|
||||
mock_external_service_dependencies["disable_task"].delay.assert_called_once()
|
||||
|
||||
def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_enable_app_annotation_cached_job(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test enabling app annotation when job is already cached.
|
||||
"""
|
||||
@ -685,7 +694,9 @@ class TestAnnotationService:
|
||||
# Clean up
|
||||
redis_client.delete(enable_app_annotation_key)
|
||||
|
||||
def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_annotation_hit_histories_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of annotation hit histories.
|
||||
"""
|
||||
@ -728,7 +739,9 @@ class TestAnnotationService:
|
||||
assert history.app_id == app.id
|
||||
assert history.account_id == account.id
|
||||
|
||||
def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_add_annotation_history_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful addition of annotation history.
|
||||
"""
|
||||
@ -763,16 +776,15 @@ class TestAnnotationService:
|
||||
)
|
||||
|
||||
# Verify hit count was incremented
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(annotation)
|
||||
db_session_with_containers.refresh(annotation)
|
||||
assert annotation.hit_count == initial_hit_count + 1
|
||||
|
||||
# Verify history was created
|
||||
from models.model import AppAnnotationHitHistory
|
||||
|
||||
history = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
db_session_with_containers.query(AppAnnotationHitHistory)
|
||||
.where(
|
||||
AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id
|
||||
)
|
||||
@ -786,7 +798,9 @@ class TestAnnotationService:
|
||||
assert history.score == score
|
||||
assert history.source == "console"
|
||||
|
||||
def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_annotation_by_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of annotation by ID.
|
||||
"""
|
||||
@ -811,7 +825,9 @@ class TestAnnotationService:
|
||||
assert retrieved_annotation.content == annotation_args["answer"]
|
||||
assert retrieved_annotation.account_id == account.id
|
||||
|
||||
def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_batch_import_app_annotations_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful batch import of app annotations.
|
||||
"""
|
||||
@ -854,7 +870,7 @@ class TestAnnotationService:
|
||||
mock_external_service_dependencies["batch_import_task"].delay.assert_called_once()
|
||||
|
||||
def test_batch_import_app_annotations_empty_file(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test batch import with empty CSV file.
|
||||
@ -889,7 +905,7 @@ class TestAnnotationService:
|
||||
assert "empty" in result["error_msg"].lower()
|
||||
|
||||
def test_batch_import_app_annotations_quota_exceeded(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test batch import when quota is exceeded.
|
||||
@ -935,7 +951,7 @@ class TestAnnotationService:
|
||||
assert "limit" in result["error_msg"].lower()
|
||||
|
||||
def test_get_app_annotation_setting_by_app_id_enabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting enabled app annotation setting by app ID.
|
||||
@ -944,7 +960,6 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotation setting
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.model import AppAnnotationSetting
|
||||
|
||||
@ -956,8 +971,8 @@ class TestAnnotationService:
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(collection_binding)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
@ -967,8 +982,8 @@ class TestAnnotationService:
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(annotation_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Get annotation setting
|
||||
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id)
|
||||
@ -981,7 +996,7 @@ class TestAnnotationService:
|
||||
assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002"
|
||||
|
||||
def test_get_app_annotation_setting_by_app_id_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting disabled app annotation setting by app ID.
|
||||
@ -996,7 +1011,7 @@ class TestAnnotationService:
|
||||
assert result["enabled"] is False
|
||||
|
||||
def test_update_app_annotation_setting_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful update of app annotation setting.
|
||||
@ -1005,7 +1020,6 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotation setting first
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.model import AppAnnotationSetting
|
||||
|
||||
@ -1017,8 +1031,8 @@ class TestAnnotationService:
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(collection_binding)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
@ -1028,8 +1042,8 @@ class TestAnnotationService:
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(annotation_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update annotation setting
|
||||
update_args = {
|
||||
@ -1046,11 +1060,11 @@ class TestAnnotationService:
|
||||
assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002"
|
||||
|
||||
# Verify database was updated
|
||||
db.session.refresh(annotation_setting)
|
||||
db_session_with_containers.refresh(annotation_setting)
|
||||
assert annotation_setting.score_threshold == 0.9
|
||||
|
||||
def test_export_annotation_list_by_app_id_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful export of annotation list by app ID.
|
||||
@ -1083,7 +1097,7 @@ class TestAnnotationService:
|
||||
assert annotation.created_at <= exported_annotations[i - 1].created_at
|
||||
|
||||
def test_export_annotation_list_by_app_id_app_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test export of annotation list when app is not found.
|
||||
@ -1099,7 +1113,7 @@ class TestAnnotationService:
|
||||
AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id)
|
||||
|
||||
def test_insert_app_annotation_directly_with_setting_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful direct insertion of app annotation with annotation setting enabled.
|
||||
@ -1108,7 +1122,6 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotation setting first
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.model import AppAnnotationSetting
|
||||
|
||||
@ -1120,8 +1133,8 @@ class TestAnnotationService:
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(collection_binding)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
@ -1131,8 +1144,8 @@ class TestAnnotationService:
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(annotation_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup annotation data
|
||||
annotation_args = {
|
||||
@ -1161,7 +1174,7 @@ class TestAnnotationService:
|
||||
assert call_args[4] == collection_binding.id # collection_binding_id
|
||||
|
||||
def test_update_app_annotation_directly_with_setting_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful direct update of app annotation with annotation setting enabled.
|
||||
@ -1170,7 +1183,6 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotation setting first
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.model import AppAnnotationSetting
|
||||
|
||||
@ -1182,8 +1194,8 @@ class TestAnnotationService:
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(collection_binding)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
@ -1193,8 +1205,8 @@ class TestAnnotationService:
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(annotation_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# First, create an annotation
|
||||
original_args = {
|
||||
@ -1234,7 +1246,7 @@ class TestAnnotationService:
|
||||
assert call_args[4] == collection_binding.id # collection_binding_id
|
||||
|
||||
def test_delete_app_annotation_with_setting_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful deletion of app annotation with annotation setting enabled.
|
||||
@ -1243,7 +1255,6 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotation setting first
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.model import AppAnnotationSetting
|
||||
|
||||
@ -1255,8 +1266,8 @@ class TestAnnotationService:
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(collection_binding)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
@ -1267,8 +1278,8 @@ class TestAnnotationService:
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(annotation_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create an annotation first
|
||||
annotation_args = {
|
||||
@ -1285,7 +1296,9 @@ class TestAnnotationService:
|
||||
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
|
||||
|
||||
# Verify annotation was deleted
|
||||
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
|
||||
deleted_annotation = (
|
||||
db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
|
||||
)
|
||||
assert deleted_annotation is None
|
||||
|
||||
# Verify delete_annotation_index_task was called
|
||||
@ -1297,7 +1310,7 @@ class TestAnnotationService:
|
||||
assert call_args[3] == collection_binding.id # collection_binding_id
|
||||
|
||||
def test_up_insert_app_annotation_from_message_with_setting_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creating annotation from message with annotation setting enabled.
|
||||
@ -1306,7 +1319,6 @@ class TestAnnotationService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create annotation setting first
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetCollectionBinding
|
||||
from models.model import AppAnnotationSetting
|
||||
|
||||
@ -1318,8 +1330,8 @@ class TestAnnotationService:
|
||||
collection_name=f"annotation_collection_{fake.uuid4()}",
|
||||
)
|
||||
collection_binding.id = str(fake.uuid4())
|
||||
db.session.add(collection_binding)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(collection_binding)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Create annotation setting
|
||||
annotation_setting = AppAnnotationSetting(
|
||||
@ -1329,12 +1341,12 @@ class TestAnnotationService:
|
||||
created_user_id=account.id,
|
||||
updated_user_id=account.id,
|
||||
)
|
||||
db.session.add(annotation_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(annotation_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create a conversation and message first
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Setup annotation data with message_id
|
||||
annotation_args = {
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.api_based_extension import APIBasedExtension
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -31,7 +32,7 @@ class TestAPIBasedExtensionService:
|
||||
"requestor_instance": mock_requestor_instance,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -61,7 +62,7 @@ class TestAPIBasedExtensionService:
|
||||
|
||||
return account, tenant
|
||||
|
||||
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful saving of API-based extension.
|
||||
"""
|
||||
@ -90,15 +91,16 @@ class TestAPIBasedExtensionService:
|
||||
assert saved_extension.created_at is not None
|
||||
|
||||
# Verify extension was saved to database
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(saved_extension)
|
||||
db_session_with_containers.refresh(saved_extension)
|
||||
assert saved_extension.id is not None
|
||||
|
||||
# Verify ping connection was called
|
||||
mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
|
||||
|
||||
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_validation_errors(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation errors when saving extension with invalid data.
|
||||
"""
|
||||
@ -132,7 +134,9 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_all_by_tenant_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of all extensions by tenant ID.
|
||||
"""
|
||||
@ -169,7 +173,7 @@ class TestAPIBasedExtensionService:
|
||||
# Verify descending order (newer first)
|
||||
assert extension.created_at <= extension_list[i - 1].created_at
|
||||
|
||||
def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of extension by tenant ID and extension ID.
|
||||
"""
|
||||
@ -200,7 +204,9 @@ class TestAPIBasedExtensionService:
|
||||
assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted
|
||||
assert retrieved_extension.created_at is not None
|
||||
|
||||
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_with_tenant_id_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of extension when extension is not found.
|
||||
"""
|
||||
@ -214,7 +220,7 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
|
||||
|
||||
def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of extension.
|
||||
"""
|
||||
@ -238,12 +244,15 @@ class TestAPIBasedExtensionService:
|
||||
APIBasedExtensionService.delete(created_extension)
|
||||
|
||||
# Verify extension was deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
|
||||
deleted_extension = (
|
||||
db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
|
||||
)
|
||||
assert deleted_extension is None
|
||||
|
||||
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_duplicate_name(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation error when saving extension with duplicate name.
|
||||
"""
|
||||
@ -272,7 +281,9 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(extension_data2)
|
||||
|
||||
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_update_existing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful update of existing extension.
|
||||
"""
|
||||
@ -329,7 +340,9 @@ class TestAPIBasedExtensionService:
|
||||
assert retrieved_extension.api_endpoint == new_endpoint
|
||||
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
|
||||
|
||||
def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_connection_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test connection error when saving extension with invalid endpoint.
|
||||
"""
|
||||
@ -356,7 +369,7 @@ class TestAPIBasedExtensionService:
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_invalid_api_key_length(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation error when saving extension with API key that is too short.
|
||||
@ -378,7 +391,7 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test validation errors when saving extension with empty required fields.
|
||||
"""
|
||||
@ -412,7 +425,9 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_all_by_tenant_id_empty_list(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of extensions when no extensions exist for tenant.
|
||||
"""
|
||||
@ -428,7 +443,9 @@ class TestAPIBasedExtensionService:
|
||||
assert len(extension_list) == 0
|
||||
assert extension_list == []
|
||||
|
||||
def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_invalid_ping_response(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation error when ping response is invalid.
|
||||
"""
|
||||
@ -452,7 +469,9 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_extension_missing_ping_result(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation error when ping response is missing result field.
|
||||
"""
|
||||
@ -476,7 +495,9 @@ class TestAPIBasedExtensionService:
|
||||
with pytest.raises(ValueError, match="{'status': 'ok'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
|
||||
def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_with_tenant_id_wrong_tenant(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of extension when tenant ID doesn't match.
|
||||
"""
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models.model import EndUser
|
||||
@ -118,7 +119,9 @@ class TestAppGenerateService:
|
||||
"global_dify_config": mock_global_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"):
|
||||
def _create_test_app_and_account(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat"
|
||||
):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -169,7 +172,7 @@ class TestAppGenerateService:
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app):
|
||||
def _create_test_workflow(self, db_session_with_containers: Session, app):
|
||||
"""
|
||||
Helper method to create a test workflow for testing.
|
||||
|
||||
@ -191,14 +194,14 @@ class TestAppGenerateService:
|
||||
status="published",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return workflow
|
||||
|
||||
def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_completion_mode_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful generation for completion mode app.
|
||||
"""
|
||||
@ -226,7 +229,7 @@ class TestAppGenerateService:
|
||||
mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful generation for chat mode app.
|
||||
"""
|
||||
@ -250,7 +253,9 @@ class TestAppGenerateService:
|
||||
mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_agent_chat_mode_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful generation for agent chat mode app.
|
||||
"""
|
||||
@ -274,7 +279,9 @@ class TestAppGenerateService:
|
||||
mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
|
||||
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_advanced_chat_mode_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful generation for advanced chat mode app.
|
||||
"""
|
||||
@ -300,7 +307,9 @@ class TestAppGenerateService:
|
||||
"advanced_chat_generator"
|
||||
].return_value.convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_workflow_mode_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful generation for workflow mode app.
|
||||
"""
|
||||
@ -324,7 +333,9 @@ class TestAppGenerateService:
|
||||
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once()
|
||||
mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_specific_workflow_id(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with a specific workflow ID.
|
||||
"""
|
||||
@ -355,7 +366,9 @@ class TestAppGenerateService:
|
||||
"workflow_service"
|
||||
].return_value.get_published_workflow_by_id.assert_called_once()
|
||||
|
||||
def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_debugger_invoke_from(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with debugger invoke from.
|
||||
"""
|
||||
@ -378,7 +391,9 @@ class TestAppGenerateService:
|
||||
# Verify draft workflow was fetched for debugger
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once()
|
||||
|
||||
def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_non_streaming_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with non-streaming mode.
|
||||
"""
|
||||
@ -401,7 +416,7 @@ class TestAppGenerateService:
|
||||
# Verify rate limit exit was called for non-streaming mode
|
||||
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
|
||||
|
||||
def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with EndUser instead of Account.
|
||||
"""
|
||||
@ -421,10 +436,8 @@ class TestAppGenerateService:
|
||||
session_id=fake.uuid4(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup test arguments
|
||||
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
|
||||
@ -438,7 +451,7 @@ class TestAppGenerateService:
|
||||
assert result == ["test_response"]
|
||||
|
||||
def test_generate_with_billing_enabled_sandbox_plan(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with billing enabled and sandbox plan.
|
||||
@ -466,7 +479,9 @@ class TestAppGenerateService:
|
||||
# Verify billing service was called to consume quota
|
||||
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
|
||||
|
||||
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_invalid_app_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with invalid app mode.
|
||||
"""
|
||||
@ -491,7 +506,7 @@ class TestAppGenerateService:
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_id_format_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with invalid workflow ID format.
|
||||
@ -518,7 +533,7 @@ class TestAppGenerateService:
|
||||
assert "Invalid workflow_id format" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_not_found_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when workflow is not found.
|
||||
@ -552,7 +567,7 @@ class TestAppGenerateService:
|
||||
assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_not_initialized_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when workflow is not initialized for debugger.
|
||||
@ -578,7 +593,7 @@ class TestAppGenerateService:
|
||||
assert "Workflow not initialized" in str(exc_info.value)
|
||||
|
||||
def test_generate_with_workflow_not_published_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation when workflow is not published for non-debugger.
|
||||
@ -604,7 +619,7 @@ class TestAppGenerateService:
|
||||
assert "Workflow not published" in str(exc_info.value)
|
||||
|
||||
def test_generate_single_iteration_advanced_chat_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single iteration generation for advanced chat mode.
|
||||
@ -631,7 +646,7 @@ class TestAppGenerateService:
|
||||
].return_value.single_iteration_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_iteration_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single iteration generation for workflow mode.
|
||||
@ -658,7 +673,7 @@ class TestAppGenerateService:
|
||||
].return_value.single_iteration_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_iteration_invalid_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test single iteration generation with invalid app mode.
|
||||
@ -681,7 +696,7 @@ class TestAppGenerateService:
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_single_loop_advanced_chat_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single loop generation for advanced chat mode.
|
||||
@ -708,7 +723,7 @@ class TestAppGenerateService:
|
||||
].return_value.single_loop_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_loop_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful single loop generation for workflow mode.
|
||||
@ -732,7 +747,9 @@ class TestAppGenerateService:
|
||||
# Verify workflow generator was called
|
||||
mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once()
|
||||
|
||||
def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_single_loop_invalid_mode(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test single loop generation with invalid app mode.
|
||||
"""
|
||||
@ -753,7 +770,9 @@ class TestAppGenerateService:
|
||||
# Verify error message
|
||||
assert "Invalid app mode" in str(exc_info.value)
|
||||
|
||||
def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_more_like_this_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful more like this generation.
|
||||
"""
|
||||
@ -778,7 +797,7 @@ class TestAppGenerateService:
|
||||
].return_value.generate_more_like_this.assert_called_once()
|
||||
|
||||
def test_generate_more_like_this_with_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test more like this generation with EndUser.
|
||||
@ -799,10 +818,8 @@ class TestAppGenerateService:
|
||||
session_id=fake.uuid4(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
message_id = fake.uuid4()
|
||||
|
||||
@ -815,7 +832,7 @@ class TestAppGenerateService:
|
||||
assert result == ["more_like_this_response"]
|
||||
|
||||
def test_get_max_active_requests_with_app_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting max active requests with app-specific limit.
|
||||
@ -835,7 +852,7 @@ class TestAppGenerateService:
|
||||
assert result == 10
|
||||
|
||||
def test_get_max_active_requests_with_config_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting max active requests with config limit being smaller.
|
||||
@ -856,7 +873,7 @@ class TestAppGenerateService:
|
||||
assert result <= 100
|
||||
|
||||
def test_get_max_active_requests_with_zero_limits(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting max active requests with zero limits (infinite).
|
||||
@ -875,7 +892,9 @@ class TestAppGenerateService:
|
||||
# Verify the result (should return config limit when app limit is 0)
|
||||
assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS
|
||||
|
||||
def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_exception_cleanup(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that rate limit exit is called when an exception occurs.
|
||||
"""
|
||||
@ -904,7 +923,9 @@ class TestAppGenerateService:
|
||||
# Verify rate limit exit was called for cleanup
|
||||
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
|
||||
|
||||
def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_agent_mode_detection(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with agent mode detection based on app configuration.
|
||||
"""
|
||||
@ -932,7 +953,7 @@ class TestAppGenerateService:
|
||||
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
|
||||
|
||||
def test_generate_with_different_invoke_from_values(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test generation with different invoke from values.
|
||||
@ -962,7 +983,7 @@ class TestAppGenerateService:
|
||||
# Verify the result
|
||||
assert result == ["test_response"]
|
||||
|
||||
def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test generation with complex arguments including files and external trace ID.
|
||||
"""
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants.model_template import default_app_templates
|
||||
from models import Account
|
||||
@ -44,7 +45,7 @@ class TestAppService:
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app creation with basic parameters.
|
||||
"""
|
||||
@ -98,7 +99,9 @@ class TestAppService:
|
||||
assert app.is_public is False
|
||||
assert app.is_universal is False
|
||||
|
||||
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_app_with_different_modes(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app creation with different app modes.
|
||||
"""
|
||||
@ -141,7 +144,7 @@ class TestAppService:
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.created_by == account.id
|
||||
|
||||
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app retrieval.
|
||||
"""
|
||||
@ -189,7 +192,7 @@ class TestAppService:
|
||||
assert retrieved_app.tenant_id == created_app.tenant_id
|
||||
assert retrieved_app.created_by == created_app.created_by
|
||||
|
||||
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful paginated app list retrieval.
|
||||
"""
|
||||
@ -243,7 +246,9 @@ class TestAppService:
|
||||
assert app.tenant_id == tenant.id
|
||||
assert app.mode == "chat"
|
||||
|
||||
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_paginate_apps_with_filters(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test paginated app list with various filters.
|
||||
"""
|
||||
@ -316,7 +321,9 @@ class TestAppService:
|
||||
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
|
||||
assert len(my_apps.items) == 1
|
||||
|
||||
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_paginate_apps_with_tag_filters(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test paginated app list with tag filters.
|
||||
"""
|
||||
@ -386,7 +393,7 @@ class TestAppService:
|
||||
# Should return None when no apps match tag filter
|
||||
assert paginated_apps is None
|
||||
|
||||
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app update with all fields.
|
||||
"""
|
||||
@ -455,7 +462,7 @@ class TestAppService:
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app name update.
|
||||
"""
|
||||
@ -508,7 +515,7 @@ class TestAppService:
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app icon update.
|
||||
"""
|
||||
@ -565,7 +572,9 @@ class TestAppService:
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_app_site_status_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful app site status update.
|
||||
"""
|
||||
@ -623,7 +632,9 @@ class TestAppService:
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_app_api_status_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful app API status update.
|
||||
"""
|
||||
@ -681,7 +692,9 @@ class TestAppService:
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_app_site_status_no_change(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app site status update when status doesn't change.
|
||||
"""
|
||||
@ -732,7 +745,7 @@ class TestAppService:
|
||||
assert updated_app.tenant_id == app.tenant_id
|
||||
assert updated_app.created_by == app.created_by
|
||||
|
||||
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app deletion.
|
||||
"""
|
||||
@ -778,12 +791,13 @@ class TestAppService:
|
||||
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
|
||||
|
||||
# Verify app was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_app = db.session.query(App).filter_by(id=app_id).first()
|
||||
deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first()
|
||||
assert deleted_app is None
|
||||
|
||||
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_app_with_related_data(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app deletion with related data cleanup.
|
||||
"""
|
||||
@ -839,12 +853,11 @@ class TestAppService:
|
||||
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
|
||||
|
||||
# Verify app was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_app = db.session.query(App).filter_by(id=app_id).first()
|
||||
deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first()
|
||||
assert deleted_app is None
|
||||
|
||||
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app metadata retrieval.
|
||||
"""
|
||||
@ -883,7 +896,7 @@ class TestAppService:
|
||||
assert "tool_icons" in app_meta
|
||||
# Note: get_app_meta currently only returns tool_icons
|
||||
|
||||
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app code retrieval by app ID.
|
||||
"""
|
||||
@ -923,7 +936,7 @@ class TestAppService:
|
||||
assert app_code is not None
|
||||
assert len(app_code) > 0
|
||||
|
||||
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful app ID retrieval by app code.
|
||||
"""
|
||||
@ -963,10 +976,9 @@ class TestAppService:
|
||||
site.status = "normal"
|
||||
site.default_language = "en-US"
|
||||
site.customize_token_strategy = "uuid"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(site)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Get app ID by code
|
||||
app_id = AppService.get_app_id_by_code(site.code)
|
||||
@ -974,7 +986,7 @@ class TestAppService:
|
||||
# Verify app ID was retrieved correctly
|
||||
assert app_id == app.id
|
||||
|
||||
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test app creation with invalid mode.
|
||||
"""
|
||||
@ -1010,7 +1022,7 @@ class TestAppService:
|
||||
app_service.create_app(tenant.id, app_args, account)
|
||||
|
||||
def test_get_apps_with_special_characters_in_name(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test app retrieval with special characters in name search to verify SQL injection prevention.
|
||||
|
||||
@ -9,10 +9,10 @@ from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
|
||||
from services.dataset_service import DatasetService
|
||||
@ -25,7 +25,9 @@ class DatasetServiceIntegrationDataFactory:
|
||||
"""Factory for creating real database entities used by integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""Create an account and tenant, then bind the account as current tenant member."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
@ -34,8 +36,8 @@ class DatasetServiceIntegrationDataFactory:
|
||||
status="active",
|
||||
)
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db.session.add_all([account, tenant])
|
||||
db.session.flush()
|
||||
db_session_with_containers.add_all([account, tenant])
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
@ -43,8 +45,8 @@ class DatasetServiceIntegrationDataFactory:
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# Keep tenant context on the in-memory user without opening a separate session.
|
||||
account.role = role
|
||||
@ -53,6 +55,7 @@ class DatasetServiceIntegrationDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
name: str = "Test Dataset",
|
||||
@ -82,12 +85,14 @@ class DatasetServiceIntegrationDataFactory:
|
||||
collection_binding_id=collection_binding_id,
|
||||
chunk_structure=chunk_structure,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.flush()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document:
|
||||
def create_document(
|
||||
db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt"
|
||||
) -> Document:
|
||||
"""Create a document row belonging to the given dataset."""
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
@ -102,8 +107,8 @@ class DatasetServiceIntegrationDataFactory:
|
||||
indexing_status="completed",
|
||||
doc_form="text_model",
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.flush()
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
@ -118,10 +123,10 @@ class DatasetServiceIntegrationDataFactory:
|
||||
class TestDatasetServiceCreateDataset:
|
||||
"""Integration coverage for DatasetService.create_empty_dataset."""
|
||||
|
||||
def test_create_internal_dataset_basic_success(self, db_session_with_containers):
|
||||
def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session):
|
||||
"""Create a basic internal dataset with minimal configuration."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
@ -133,17 +138,17 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
created_dataset = db.session.get(Dataset, result.id)
|
||||
created_dataset = db_session_with_containers.get(Dataset, result.id)
|
||||
assert created_dataset is not None
|
||||
assert created_dataset.provider == "vendor"
|
||||
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
|
||||
assert created_dataset.embedding_model_provider is None
|
||||
assert created_dataset.embedding_model is None
|
||||
|
||||
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers):
|
||||
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session):
|
||||
"""Create an internal dataset with economy indexing and no embedding model."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
@ -155,15 +160,15 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.indexing_technique == "economy"
|
||||
assert result.embedding_model_provider is None
|
||||
assert result.embedding_model is None
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers):
|
||||
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session):
|
||||
"""Create a high-quality dataset and persist embedding model settings."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
|
||||
|
||||
# Act
|
||||
@ -179,7 +184,7 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
@ -188,11 +193,12 @@ class TestDatasetServiceCreateDataset:
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
|
||||
def test_create_dataset_duplicate_name_error(self, db_session_with_containers):
|
||||
def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session):
|
||||
"""Raise duplicate-name error when the same tenant already has the name."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Duplicate Dataset",
|
||||
@ -209,10 +215,10 @@ class TestDatasetServiceCreateDataset:
|
||||
account=account,
|
||||
)
|
||||
|
||||
def test_create_external_dataset_success(self, db_session_with_containers):
|
||||
def test_create_external_dataset_success(self, db_session_with_containers: Session):
|
||||
"""Create an external dataset and persist external knowledge binding."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
external_knowledge_api_id = str(uuid4())
|
||||
external_knowledge_id = "knowledge-123"
|
||||
|
||||
@ -231,16 +237,16 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
|
||||
binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
|
||||
assert result.provider == "external"
|
||||
assert binding is not None
|
||||
assert binding.external_knowledge_id == external_knowledge_id
|
||||
assert binding.external_knowledge_api_id == external_knowledge_api_id
|
||||
|
||||
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers):
|
||||
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session):
|
||||
"""Create a high-quality dataset with retrieval/reranking settings."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
@ -271,14 +277,16 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.retrieval_model == retrieval_model.model_dump()
|
||||
mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0")
|
||||
|
||||
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(self, db_session_with_containers):
|
||||
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(
|
||||
self, db_session_with_containers: Session
|
||||
):
|
||||
"""Create high-quality dataset with explicitly configured embedding model."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
embedding_provider = "openai"
|
||||
embedding_model_name = "text-embedding-3-small"
|
||||
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model(
|
||||
@ -303,7 +311,7 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_provider
|
||||
assert result.embedding_model == embedding_model_name
|
||||
@ -315,10 +323,10 @@ class TestDatasetServiceCreateDataset:
|
||||
model=embedding_model_name,
|
||||
)
|
||||
|
||||
def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers):
|
||||
def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session):
|
||||
"""Persist retrieval model settings when creating an internal dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
retrieval_model = RetrievalModel(
|
||||
search_method=RetrievalMethod.SEMANTIC_SEARCH,
|
||||
reranking_enable=False,
|
||||
@ -338,13 +346,13 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.retrieval_model == retrieval_model.model_dump()
|
||||
|
||||
def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers):
|
||||
def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session):
|
||||
"""Persist canonical custom permission when creating an internal dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
|
||||
# Act
|
||||
result = DatasetService.create_empty_dataset(
|
||||
@ -357,13 +365,13 @@ class TestDatasetServiceCreateDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.permission == DatasetPermissionEnum.ALL_TEAM
|
||||
|
||||
def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers):
|
||||
def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session):
|
||||
"""Raise error when external API template does not exist."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
external_knowledge_api_id = str(uuid4())
|
||||
|
||||
# Act / Assert
|
||||
@ -381,10 +389,10 @@ class TestDatasetServiceCreateDataset:
|
||||
external_knowledge_id="knowledge-123",
|
||||
)
|
||||
|
||||
def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
|
||||
def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session):
|
||||
"""Raise error when external knowledge id is missing for external dataset creation."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
external_knowledge_api_id = str(uuid4())
|
||||
|
||||
# Act / Assert
|
||||
@ -406,10 +414,10 @@ class TestDatasetServiceCreateDataset:
|
||||
class TestDatasetServiceCreateRagPipelineDataset:
|
||||
"""Integration coverage for DatasetService.create_empty_rag_pipeline_dataset."""
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers):
|
||||
def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session):
|
||||
"""Create rag-pipeline dataset and pipeline rows when a name is provided."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="RAG Pipeline Dataset",
|
||||
@ -425,8 +433,8 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
created_dataset = db.session.get(Dataset, result.id)
|
||||
created_pipeline = db.session.get(Pipeline, result.pipeline_id)
|
||||
created_dataset = db_session_with_containers.get(Dataset, result.id)
|
||||
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
|
||||
assert created_dataset is not None
|
||||
assert created_dataset.name == entity.name
|
||||
assert created_dataset.runtime_mode == "rag_pipeline"
|
||||
@ -436,10 +444,10 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
assert created_pipeline.name == entity.name
|
||||
assert created_pipeline.created_by == account.id
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers):
|
||||
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session):
|
||||
"""Create rag-pipeline dataset with generated incremental name when input name is empty."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
generated_name = "Untitled 1"
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
@ -460,25 +468,26 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
created_pipeline = db.session.get(Pipeline, result.pipeline_id)
|
||||
db_session_with_containers.refresh(result)
|
||||
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
|
||||
assert result.name == generated_name
|
||||
assert created_pipeline is not None
|
||||
assert created_pipeline.name == generated_name
|
||||
mock_generate_name.assert_called_once()
|
||||
|
||||
def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers):
|
||||
def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session):
|
||||
"""Raise duplicate-name error when rag-pipeline dataset name already exists."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
duplicate_name = "Duplicate RAG Dataset"
|
||||
DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name=duplicate_name,
|
||||
indexing_technique=None,
|
||||
)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name=duplicate_name,
|
||||
@ -496,10 +505,10 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity
|
||||
)
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers):
|
||||
def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session):
|
||||
"""Persist canonical custom permission for rag-pipeline dataset creation."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
|
||||
entity = RagPipelineDatasetCreateEntity(
|
||||
name="Custom Permission RAG Dataset",
|
||||
@ -515,13 +524,13 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.permission == DatasetPermissionEnum.ALL_TEAM
|
||||
|
||||
def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers):
|
||||
def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session):
|
||||
"""Persist icon metadata when creating rag-pipeline dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
icon_info = IconInfo(
|
||||
icon="📚",
|
||||
icon_background="#E8F5E9",
|
||||
@ -542,23 +551,25 @@ class TestDatasetServiceCreateRagPipelineDataset:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.icon_info == icon_info.model_dump()
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
"""Integration coverage for SQL-backed update and delete behavior."""
|
||||
|
||||
def test_update_dataset_duplicate_name_error(self, db_session_with_containers):
|
||||
def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session):
|
||||
"""Reject update when target name already exists within the same tenant."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
source_dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Source Dataset",
|
||||
)
|
||||
DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Existing Dataset",
|
||||
@ -568,17 +579,20 @@ class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
with pytest.raises(ValueError, match="Dataset name already exists"):
|
||||
DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account)
|
||||
|
||||
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
|
||||
def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session):
|
||||
"""Delete a dataset that already has documents."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
chunk_structure="text_model",
|
||||
)
|
||||
DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id)
|
||||
DatasetServiceIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, created_by=account.id
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
|
||||
@ -586,14 +600,15 @@ class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_empty_dataset_success(self, db_session_with_containers):
|
||||
def test_delete_empty_dataset_success(self, db_session_with_containers: Session):
|
||||
"""Delete a dataset that has no documents and no indexing technique."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique=None,
|
||||
@ -606,14 +621,15 @@ class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
|
||||
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session):
|
||||
"""Delete dataset when indexing_technique is None but doc_form path still exists."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique=None,
|
||||
@ -626,17 +642,17 @@ class TestDatasetServiceUpdateAndDeleteDataset:
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
assert db.session.get(Dataset, dataset.id) is None
|
||||
assert db_session_with_containers.get(Dataset, dataset.id) is None
|
||||
dataset_deleted_signal.send.assert_called_once_with(dataset)
|
||||
|
||||
|
||||
class TestDatasetServiceRetrievalConfiguration:
|
||||
"""Integration coverage for retrieval configuration persistence."""
|
||||
|
||||
def test_get_dataset_retrieval_configuration(self, db_session_with_containers):
|
||||
def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session):
|
||||
"""Return retrieval configuration that is persisted in SQL."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 5,
|
||||
@ -644,6 +660,7 @@ class TestDatasetServiceRetrievalConfiguration:
|
||||
"reranking_enable": True,
|
||||
}
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
retrieval_model=retrieval_model,
|
||||
@ -658,11 +675,12 @@ class TestDatasetServiceRetrievalConfiguration:
|
||||
assert result.retrieval_model["search_method"] == "semantic_search"
|
||||
assert result.retrieval_model["top_k"] == 5
|
||||
|
||||
def test_update_dataset_retrieval_configuration(self, db_session_with_containers):
|
||||
def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session):
|
||||
"""Persist retrieval configuration updates through DatasetService.update_dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
indexing_technique="high_quality",
|
||||
@ -684,6 +702,6 @@ class TestDatasetServiceRetrievalConfiguration:
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, account)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert result.id == dataset.id
|
||||
assert dataset.retrieval_model == update_data["retrieval_model"]
|
||||
|
||||
@ -11,8 +11,8 @@ from unittest.mock import call, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document
|
||||
from services.dataset_service import DocumentService
|
||||
from services.errors.document import DocumentIndexingError
|
||||
@ -32,6 +32,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers: Session,
|
||||
dataset_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
name: str = "Test Dataset",
|
||||
@ -47,12 +48,13 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
||||
if dataset_id:
|
||||
dataset.id = dataset_id
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_document(
|
||||
db_session_with_containers: Session,
|
||||
dataset: Dataset,
|
||||
document_id: str | None = None,
|
||||
name: str = "test_document.pdf",
|
||||
@ -89,13 +91,14 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
||||
for key, value in kwargs.items():
|
||||
setattr(document, key, value)
|
||||
|
||||
db.session.add(document)
|
||||
db_session_with_containers.add(document)
|
||||
if commit:
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def create_multiple_documents(
|
||||
db_session_with_containers: Session,
|
||||
dataset: Dataset,
|
||||
document_ids: list[str],
|
||||
enabled: bool = True,
|
||||
@ -106,6 +109,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
||||
documents: list[Document] = []
|
||||
for index, doc_id in enumerate(document_ids, start=1):
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
document_id=doc_id,
|
||||
name=f"document_{doc_id}.pdf",
|
||||
@ -116,7 +120,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
|
||||
commit=False,
|
||||
)
|
||||
documents.append(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
@ -173,13 +177,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
assert document.archived_at is None
|
||||
assert document.archived_by is None
|
||||
|
||||
def test_batch_update_enable_documents_success(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Enable disabled documents and trigger indexing side effects."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document_ids = [str(uuid4()), str(uuid4())]
|
||||
disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
document_ids=document_ids,
|
||||
enabled=False,
|
||||
@ -192,7 +197,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
|
||||
# Assert
|
||||
for document in disabled_docs:
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_enabled(document, FIXED_TIME)
|
||||
|
||||
expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
|
||||
@ -203,13 +208,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls)
|
||||
|
||||
def test_batch_update_enable_already_enabled_document_skipped(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Skip enable operation for already-enabled documents."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True)
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=True
|
||||
)
|
||||
|
||||
# Act
|
||||
DocumentService.batch_update_document_status(
|
||||
@ -220,18 +227,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is True
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_disable_documents_success(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Disable completed documents and trigger remove-index tasks."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document_ids = [str(uuid4()), str(uuid4())]
|
||||
enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
document_ids=document_ids,
|
||||
enabled=True,
|
||||
@ -248,7 +256,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
|
||||
# Assert
|
||||
for document in enabled_docs:
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_disabled(document, user.id, FIXED_TIME)
|
||||
|
||||
expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
|
||||
@ -259,13 +267,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls)
|
||||
|
||||
def test_batch_update_disable_already_disabled_document_skipped(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Skip disable operation for already-disabled documents."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=False,
|
||||
indexing_status="completed",
|
||||
@ -281,17 +290,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(disabled_doc)
|
||||
db_session_with_containers.refresh(disabled_doc)
|
||||
assert disabled_doc.enabled is False
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["remove_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_disable_non_completed_document_error(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_disable_non_completed_document_error(
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Raise error when disabling a non-completed document."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=True,
|
||||
indexing_status="indexing",
|
||||
@ -307,13 +319,13 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
user=user,
|
||||
)
|
||||
|
||||
def test_batch_update_archive_documents_success(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Archive enabled documents and trigger remove-index task."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
dataset=dataset, enabled=True, archived=False
|
||||
db_session_with_containers, dataset=dataset, enabled=True, archived=False
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -325,21 +337,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_archived(document, user.id, FIXED_TIME)
|
||||
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
|
||||
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
|
||||
patched_dependencies["remove_task"].delay.assert_called_once_with(document.id)
|
||||
|
||||
def test_batch_update_archive_already_archived_document_skipped(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Skip archive operation for already-archived documents."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
dataset=dataset, enabled=True, archived=True
|
||||
db_session_with_containers, dataset=dataset, enabled=True, archived=True
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -351,20 +363,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.archived is True
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["remove_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_archive_disabled_document_no_index_removal(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Archive disabled document without index-removal side effects."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
dataset=dataset, enabled=False, archived=False
|
||||
db_session_with_containers, dataset=dataset, enabled=False, archived=False
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -376,18 +388,18 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_archived(document, user.id, FIXED_TIME)
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["remove_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_unarchive_documents_success(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Unarchive enabled documents and trigger add-index task."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
dataset=dataset, enabled=True, archived=True
|
||||
db_session_with_containers, dataset=dataset, enabled=True, archived=True
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -399,7 +411,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_unarchived(document)
|
||||
assert document.updated_at == FIXED_TIME
|
||||
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
|
||||
@ -407,14 +419,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
patched_dependencies["add_task"].delay.assert_called_once_with(document.id)
|
||||
|
||||
def test_batch_update_unarchive_already_unarchived_document_skipped(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Skip unarchive operation for already-unarchived documents."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
dataset=dataset, enabled=True, archived=False
|
||||
db_session_with_containers, dataset=dataset, enabled=True, archived=False
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -426,20 +438,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.archived is False
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_unarchive_disabled_document_no_index_addition(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Unarchive disabled document without index-add side effects."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
dataset=dataset, enabled=False, archived=True
|
||||
db_session_with_containers, dataset=dataset, enabled=False, archived=True
|
||||
)
|
||||
|
||||
# Act
|
||||
@ -451,20 +463,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_unarchived(document)
|
||||
assert document.updated_at == FIXED_TIME
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_document_indexing_error_redis_cache_hit(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Raise DocumentIndexingError when redis indicates active indexing."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
name="test_document.pdf",
|
||||
enabled=True,
|
||||
@ -483,12 +496,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
assert "test_document.pdf" in str(exc_info.value)
|
||||
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
|
||||
|
||||
def test_batch_update_async_task_error_handling(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Persist DB update, then propagate async task error."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
|
||||
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=False
|
||||
)
|
||||
patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error")
|
||||
|
||||
# Act / Assert
|
||||
@ -500,14 +515,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
user=user,
|
||||
)
|
||||
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_enabled(document, FIXED_TIME)
|
||||
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
|
||||
|
||||
def test_batch_update_empty_document_list(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Return early when document_ids is empty."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
|
||||
# Act
|
||||
@ -520,10 +535,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
patched_dependencies["redis_client"].get.assert_not_called()
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
|
||||
def test_batch_update_document_not_found_skipped(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies):
|
||||
"""Skip IDs that do not map to existing dataset documents."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
missing_document_id = str(uuid4())
|
||||
|
||||
@ -540,18 +555,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
patched_dependencies["redis_client"].setex.assert_not_called()
|
||||
patched_dependencies["add_task"].delay.assert_not_called()
|
||||
|
||||
def test_batch_update_mixed_document_states_and_actions(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_mixed_document_states_and_actions(
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Process only the applicable document in a mixed-state enable batch."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
|
||||
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=False
|
||||
)
|
||||
enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=True,
|
||||
position=2,
|
||||
)
|
||||
archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=True,
|
||||
archived=True,
|
||||
@ -568,9 +589,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(disabled_doc)
|
||||
db.session.refresh(enabled_doc)
|
||||
db.session.refresh(archived_doc)
|
||||
db_session_with_containers.refresh(disabled_doc)
|
||||
db_session_with_containers.refresh(enabled_doc)
|
||||
db_session_with_containers.refresh(archived_doc)
|
||||
self._assert_document_enabled(disabled_doc, FIXED_TIME)
|
||||
assert enabled_doc.enabled is True
|
||||
assert archived_doc.enabled is True
|
||||
@ -582,13 +603,16 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id)
|
||||
|
||||
def test_batch_update_large_document_list_performance(self, db_session_with_containers, patched_dependencies):
|
||||
def test_batch_update_large_document_list_performance(
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Handle large document lists with consistent updates and side effects."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
document_ids = [str(uuid4()) for _ in range(100)]
|
||||
documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
document_ids=document_ids,
|
||||
enabled=False,
|
||||
@ -604,7 +628,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
|
||||
# Assert
|
||||
for document in documents:
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
self._assert_document_enabled(document, FIXED_TIME)
|
||||
|
||||
assert patched_dependencies["redis_client"].setex.call_count == len(document_ids)
|
||||
@ -616,17 +640,26 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
|
||||
|
||||
def test_batch_update_mixed_document_states_complex_scenario(
|
||||
self, db_session_with_containers, patched_dependencies
|
||||
self, db_session_with_containers: Session, patched_dependencies
|
||||
):
|
||||
"""Process a complex mixed-state batch and update only eligible records."""
|
||||
# Arrange
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
|
||||
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
|
||||
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
|
||||
doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
|
||||
doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=2)
|
||||
doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=3)
|
||||
doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=4)
|
||||
doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=False
|
||||
)
|
||||
doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=True, position=2
|
||||
)
|
||||
doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=True, position=3
|
||||
)
|
||||
doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers, dataset=dataset, enabled=True, position=4
|
||||
)
|
||||
doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document(
|
||||
db_session_with_containers,
|
||||
dataset=dataset,
|
||||
enabled=True,
|
||||
archived=True,
|
||||
@ -645,11 +678,11 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
|
||||
)
|
||||
|
||||
# Assert
|
||||
db.session.refresh(doc1)
|
||||
db.session.refresh(doc2)
|
||||
db.session.refresh(doc3)
|
||||
db.session.refresh(doc4)
|
||||
db.session.refresh(doc5)
|
||||
db_session_with_containers.refresh(doc1)
|
||||
db_session_with_containers.refresh(doc2)
|
||||
db_session_with_containers.refresh(doc3)
|
||||
db_session_with_containers.refresh(doc4)
|
||||
db_session_with_containers.refresh(doc5)
|
||||
self._assert_document_enabled(doc1, FIXED_TIME)
|
||||
assert doc2.enabled is True
|
||||
assert doc3.enabled is True
|
||||
|
||||
@ -10,7 +10,8 @@ Tests the retrieval of document segments with pagination and filtering:
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from extensions.ext_database import db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
|
||||
from services.dataset_service import SegmentService
|
||||
@ -23,6 +24,7 @@ class SegmentServiceTestDataFactory:
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session,
|
||||
role: TenantAccountRole = TenantAccountRole.OWNER,
|
||||
tenant: Tenant | None = None,
|
||||
) -> tuple[Account, Tenant]:
|
||||
@ -33,13 +35,13 @@ class SegmentServiceTestDataFactory:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
if tenant is None:
|
||||
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
@ -47,14 +49,14 @@ class SegmentServiceTestDataFactory:
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(tenant_id: str, created_by: str) -> Dataset:
|
||||
def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset:
|
||||
"""Create a real dataset."""
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@ -67,12 +69,14 @@ class SegmentServiceTestDataFactory:
|
||||
provider="vendor",
|
||||
retrieval_model={"top_k": 2},
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document:
|
||||
def create_document(
|
||||
db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str
|
||||
) -> Document:
|
||||
"""Create a real document."""
|
||||
document = Document(
|
||||
tenant_id=tenant_id,
|
||||
@ -84,12 +88,13 @@ class SegmentServiceTestDataFactory:
|
||||
created_from="api",
|
||||
created_by=created_by,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
return document
|
||||
|
||||
@staticmethod
|
||||
def create_segment(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
document_id: str,
|
||||
@ -112,8 +117,8 @@ class SegmentServiceTestDataFactory:
|
||||
tokens=tokens,
|
||||
created_by=created_by,
|
||||
)
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
return segment
|
||||
|
||||
|
||||
@ -130,7 +135,7 @@ class TestSegmentServiceGetSegments:
|
||||
- Combined filters
|
||||
"""
|
||||
|
||||
def test_get_segments_basic_pagination(self, db_session_with_containers):
|
||||
def test_get_segments_basic_pagination(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test basic pagination functionality.
|
||||
|
||||
@ -140,11 +145,14 @@ class TestSegmentServiceGetSegments:
|
||||
- Returns segments and total count
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
segment1 = SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -153,6 +161,7 @@ class TestSegmentServiceGetSegments:
|
||||
content="First segment",
|
||||
)
|
||||
segment2 = SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -170,7 +179,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert items[0].id == segment1.id
|
||||
assert items[1].id == segment2.id
|
||||
|
||||
def test_get_segments_with_status_filter(self, db_session_with_containers):
|
||||
def test_get_segments_with_status_filter(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test filtering by status list.
|
||||
|
||||
@ -179,11 +188,14 @@ class TestSegmentServiceGetSegments:
|
||||
- Only segments with matching status are returned
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -192,6 +204,7 @@ class TestSegmentServiceGetSegments:
|
||||
status="completed",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -200,6 +213,7 @@ class TestSegmentServiceGetSegments:
|
||||
status="indexing",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -219,7 +233,7 @@ class TestSegmentServiceGetSegments:
|
||||
statuses = {item.status for item in items}
|
||||
assert statuses == {"completed", "indexing"}
|
||||
|
||||
def test_get_segments_with_empty_status_list(self, db_session_with_containers):
|
||||
def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test with empty status list.
|
||||
|
||||
@ -228,11 +242,14 @@ class TestSegmentServiceGetSegments:
|
||||
- No status filter is applied to avoid WHERE false condition
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -241,6 +258,7 @@ class TestSegmentServiceGetSegments:
|
||||
status="completed",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -256,7 +274,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
|
||||
def test_get_segments_with_keyword_search(self, db_session_with_containers):
|
||||
def test_get_segments_with_keyword_search(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test keyword search functionality.
|
||||
|
||||
@ -265,11 +283,14 @@ class TestSegmentServiceGetSegments:
|
||||
- Search pattern includes wildcards (%keyword%)
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -278,6 +299,7 @@ class TestSegmentServiceGetSegments:
|
||||
content="This contains search term in the middle",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -294,7 +316,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert total == 1
|
||||
assert "search term" in items[0].content
|
||||
|
||||
def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers):
|
||||
def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test ordering by position and id.
|
||||
|
||||
@ -304,12 +326,15 @@ class TestSegmentServiceGetSegments:
|
||||
- This prevents duplicate data across pages when positions are not unique
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
# Create segments with different positions
|
||||
seg_pos2 = SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -318,6 +343,7 @@ class TestSegmentServiceGetSegments:
|
||||
content="Position 2",
|
||||
)
|
||||
seg_pos1 = SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -326,6 +352,7 @@ class TestSegmentServiceGetSegments:
|
||||
content="Position 1",
|
||||
)
|
||||
seg_pos3 = SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -344,7 +371,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert items[1].id == seg_pos2.id
|
||||
assert items[2].id == seg_pos3.id
|
||||
|
||||
def test_get_segments_empty_results(self, db_session_with_containers):
|
||||
def test_get_segments_empty_results(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test when no segments match the criteria.
|
||||
|
||||
@ -353,7 +380,7 @@ class TestSegmentServiceGetSegments:
|
||||
- Total count is 0
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
non_existent_doc_id = str(uuid4())
|
||||
|
||||
# Act
|
||||
@ -363,7 +390,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert items == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_segments_combined_filters(self, db_session_with_containers):
|
||||
def test_get_segments_combined_filters(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test with multiple filters combined.
|
||||
|
||||
@ -372,12 +399,15 @@ class TestSegmentServiceGetSegments:
|
||||
- Status list and keyword search both applied
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
# Create segments with various statuses and content
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -387,6 +417,7 @@ class TestSegmentServiceGetSegments:
|
||||
content="This is important information",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -396,6 +427,7 @@ class TestSegmentServiceGetSegments:
|
||||
content="This is also important",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -421,7 +453,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert items[0].status == "completed"
|
||||
assert "important" in items[0].content
|
||||
|
||||
def test_get_segments_with_none_status_list(self, db_session_with_containers):
|
||||
def test_get_segments_with_none_status_list(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test with None status list.
|
||||
|
||||
@ -430,11 +462,14 @@ class TestSegmentServiceGetSegments:
|
||||
- No status filter is applied
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -443,6 +478,7 @@ class TestSegmentServiceGetSegments:
|
||||
status="completed",
|
||||
)
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
@ -462,7 +498,7 @@ class TestSegmentServiceGetSegments:
|
||||
assert len(items) == 2
|
||||
assert total == 2
|
||||
|
||||
def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers):
|
||||
def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that max_per_page is correctly set to 100.
|
||||
|
||||
@ -471,13 +507,16 @@ class TestSegmentServiceGetSegments:
|
||||
- This prevents excessive page sizes
|
||||
"""
|
||||
# Arrange
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
|
||||
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
|
||||
document = SegmentServiceTestDataFactory.create_document(
|
||||
db_session_with_containers, tenant.id, dataset.id, owner.id
|
||||
)
|
||||
|
||||
# Create 105 segments to exceed max_per_page of 100
|
||||
for i in range(105):
|
||||
SegmentServiceTestDataFactory.create_segment(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
|
||||
@ -13,7 +13,8 @@ This test suite covers:
|
||||
import json
|
||||
from uuid import uuid4
|
||||
|
||||
from extensions.ext_database import db
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
AppDatasetJoin,
|
||||
@ -31,7 +32,9 @@ class DatasetRetrievalTestDataFactory:
|
||||
"""Factory class for creating database-backed test data for dataset retrieval integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]:
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""Create an account and tenant with the specified role."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
@ -43,8 +46,8 @@ class DatasetRetrievalTestDataFactory:
|
||||
name=f"tenant-{uuid4()}",
|
||||
status="normal",
|
||||
)
|
||||
db.session.add_all([account, tenant])
|
||||
db.session.flush()
|
||||
db_session_with_containers.add_all([account, tenant])
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
@ -52,14 +55,16 @@ class DatasetRetrievalTestDataFactory:
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account:
|
||||
def create_account_in_tenant(
|
||||
db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER
|
||||
) -> Account:
|
||||
"""Create an account and add it to an existing tenant."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
@ -67,8 +72,8 @@ class DatasetRetrievalTestDataFactory:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
@ -76,14 +81,15 @@ class DatasetRetrievalTestDataFactory:
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
name: str = "Test Dataset",
|
||||
@ -101,12 +107,14 @@ class DatasetRetrievalTestDataFactory:
|
||||
provider="vendor",
|
||||
retrieval_model={"top_k": 2},
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission:
|
||||
def create_dataset_permission(
|
||||
db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str
|
||||
) -> DatasetPermission:
|
||||
"""Create a dataset permission."""
|
||||
permission = DatasetPermission(
|
||||
dataset_id=dataset_id,
|
||||
@ -114,12 +122,14 @@ class DatasetRetrievalTestDataFactory:
|
||||
account_id=account_id,
|
||||
has_permission=True,
|
||||
)
|
||||
db.session.add(permission)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(permission)
|
||||
db_session_with_containers.commit()
|
||||
return permission
|
||||
|
||||
@staticmethod
|
||||
def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule:
|
||||
def create_process_rule(
|
||||
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict
|
||||
) -> DatasetProcessRule:
|
||||
"""Create a dataset process rule."""
|
||||
process_rule = DatasetProcessRule(
|
||||
dataset_id=dataset_id,
|
||||
@ -127,12 +137,14 @@ class DatasetRetrievalTestDataFactory:
|
||||
mode=mode,
|
||||
rules=json.dumps(rules),
|
||||
)
|
||||
db.session.add(process_rule)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(process_rule)
|
||||
db_session_with_containers.commit()
|
||||
return process_rule
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery:
|
||||
def create_dataset_query(
|
||||
db_session_with_containers: Session, dataset_id: str, created_by: str, content: str
|
||||
) -> DatasetQuery:
|
||||
"""Create a dataset query."""
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset_id,
|
||||
@ -142,23 +154,23 @@ class DatasetRetrievalTestDataFactory:
|
||||
created_by_role="account",
|
||||
created_by=created_by,
|
||||
)
|
||||
db.session.add(dataset_query)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset_query)
|
||||
db_session_with_containers.commit()
|
||||
return dataset_query
|
||||
|
||||
@staticmethod
|
||||
def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin:
|
||||
def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin:
|
||||
"""Create an app-dataset join."""
|
||||
join = AppDatasetJoin(
|
||||
app_id=str(uuid4()),
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
return join
|
||||
|
||||
@staticmethod
|
||||
def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag:
|
||||
def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag:
|
||||
"""Create a knowledge tag and bind it to the target dataset."""
|
||||
tag = Tag(
|
||||
tenant_id=tenant_id,
|
||||
@ -166,8 +178,8 @@ class DatasetRetrievalTestDataFactory:
|
||||
name=f"tag-{uuid4()}",
|
||||
created_by=created_by,
|
||||
)
|
||||
db.session.add(tag)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(tag)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
binding = TagBinding(
|
||||
tenant_id=tenant_id,
|
||||
@ -175,8 +187,8 @@ class DatasetRetrievalTestDataFactory:
|
||||
target_id=target_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
return tag
|
||||
|
||||
|
||||
@ -195,15 +207,16 @@ class TestDatasetServiceGetDatasets:
|
||||
|
||||
# ==================== Basic Retrieval Tests ====================
|
||||
|
||||
def test_get_datasets_basic_pagination(self, db_session_with_containers):
|
||||
def test_get_datasets_basic_pagination(self, db_session_with_containers: Session):
|
||||
"""Test basic pagination without user or filters."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
for i in range(5):
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name=f"Dataset {i}",
|
||||
@ -217,21 +230,23 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 5
|
||||
assert total == 5
|
||||
|
||||
def test_get_datasets_with_search(self, db_session_with_containers):
|
||||
def test_get_datasets_with_search(self, db_session_with_containers: Session):
|
||||
"""Test get_datasets with search keyword."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
page = 1
|
||||
per_page = 20
|
||||
search = "test"
|
||||
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Test Dataset",
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name="Another Dataset",
|
||||
@ -245,26 +260,32 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_with_tag_filtering(self, db_session_with_containers):
|
||||
def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session):
|
||||
"""Test get_datasets with tag_ids filtering."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
dataset_1 = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
dataset_2 = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
|
||||
tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id)
|
||||
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id)
|
||||
tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(
|
||||
db_session_with_containers, tenant.id, account.id, dataset_1.id
|
||||
)
|
||||
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(
|
||||
db_session_with_containers, tenant.id, account.id, dataset_2.id
|
||||
)
|
||||
tag_ids = [tag_1.id, tag_2.id]
|
||||
|
||||
# Act
|
||||
@ -274,16 +295,17 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 2
|
||||
assert total == 2
|
||||
|
||||
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers):
|
||||
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session):
|
||||
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
page = 1
|
||||
per_page = 20
|
||||
tag_ids = []
|
||||
|
||||
for i in range(3):
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
name=f"dataset-{i}",
|
||||
@ -300,19 +322,21 @@ class TestDatasetServiceGetDatasets:
|
||||
|
||||
# ==================== Permission-Based Filtering Tests ====================
|
||||
|
||||
def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers):
|
||||
def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session):
|
||||
"""Test that without user, only ALL_TEAM datasets are shown."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
)
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=account.id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
@ -325,15 +349,18 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_owner_with_include_all(self, db_session_with_containers):
|
||||
def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session):
|
||||
"""Test that OWNER with include_all=True sees all datasets."""
|
||||
# Arrange
|
||||
owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
for i, permission in enumerate(
|
||||
[DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM]
|
||||
):
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
name=f"dataset-{i}",
|
||||
@ -353,12 +380,15 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 3
|
||||
assert total == 3
|
||||
|
||||
def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers):
|
||||
def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session):
|
||||
"""Test that normal user sees ONLY_ME datasets they created."""
|
||||
# Arrange
|
||||
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
@ -371,13 +401,18 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers):
|
||||
def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session):
|
||||
"""Test that normal user sees ALL_TEAM datasets."""
|
||||
# Arrange
|
||||
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
|
||||
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
|
||||
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
@ -390,18 +425,25 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers):
|
||||
def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session):
|
||||
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
|
||||
# Arrange
|
||||
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
|
||||
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
|
||||
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
|
||||
)
|
||||
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.PARTIAL_TEAM,
|
||||
)
|
||||
DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id)
|
||||
DatasetRetrievalTestDataFactory.create_dataset_permission(
|
||||
db_session_with_containers, dataset.id, tenant.id, user.id
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
|
||||
@ -410,20 +452,25 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers):
|
||||
def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session):
|
||||
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
|
||||
# Arrange
|
||||
operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.DATASET_OPERATOR
|
||||
db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
|
||||
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
|
||||
)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
|
||||
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id)
|
||||
DatasetRetrievalTestDataFactory.create_dataset_permission(
|
||||
db_session_with_containers, dataset.id, tenant.id, operator.id
|
||||
)
|
||||
|
||||
# Act
|
||||
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
|
||||
@ -432,14 +479,17 @@ class TestDatasetServiceGetDatasets:
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers):
|
||||
def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session):
|
||||
"""Test that DATASET_OPERATOR without permissions returns empty result."""
|
||||
# Arrange
|
||||
operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
|
||||
role=TenantAccountRole.DATASET_OPERATOR
|
||||
db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR
|
||||
)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
|
||||
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
|
||||
)
|
||||
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
permission=DatasetPermissionEnum.ALL_TEAM,
|
||||
@ -456,11 +506,13 @@ class TestDatasetServiceGetDatasets:
|
||||
class TestDatasetServiceGetDataset:
|
||||
"""Comprehensive integration tests for DatasetService.get_dataset method."""
|
||||
|
||||
def test_get_dataset_success(self, db_session_with_containers):
|
||||
def test_get_dataset_success(self, db_session_with_containers: Session):
|
||||
"""Test successful retrieval of a single dataset."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_dataset(dataset.id)
|
||||
@ -469,7 +521,7 @@ class TestDatasetServiceGetDataset:
|
||||
assert result is not None
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_get_dataset_not_found(self, db_session_with_containers):
|
||||
def test_get_dataset_not_found(self, db_session_with_containers: Session):
|
||||
"""Test retrieval when dataset doesn't exist."""
|
||||
# Arrange
|
||||
dataset_id = str(uuid4())
|
||||
@ -484,12 +536,15 @@ class TestDatasetServiceGetDataset:
|
||||
class TestDatasetServiceGetDatasetsByIds:
|
||||
"""Comprehensive integration tests for DatasetService.get_datasets_by_ids method."""
|
||||
|
||||
def test_get_datasets_by_ids_success(self, db_session_with_containers):
|
||||
def test_get_datasets_by_ids_success(self, db_session_with_containers: Session):
|
||||
"""Test successful bulk retrieval of datasets by IDs."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
datasets = [
|
||||
DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3)
|
||||
DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
dataset_ids = [dataset.id for dataset in datasets]
|
||||
|
||||
@ -501,7 +556,7 @@ class TestDatasetServiceGetDatasetsByIds:
|
||||
assert total == 3
|
||||
assert all(dataset.id in dataset_ids for dataset in result_datasets)
|
||||
|
||||
def test_get_datasets_by_ids_empty_list(self, db_session_with_containers):
|
||||
def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session):
|
||||
"""Test get_datasets_by_ids with empty list returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
@ -514,7 +569,7 @@ class TestDatasetServiceGetDatasetsByIds:
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_datasets_by_ids_none_list(self, db_session_with_containers):
|
||||
def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session):
|
||||
"""Test get_datasets_by_ids with None returns empty result."""
|
||||
# Arrange
|
||||
tenant_id = str(uuid4())
|
||||
@ -530,17 +585,20 @@ class TestDatasetServiceGetDatasetsByIds:
|
||||
class TestDatasetServiceGetProcessRules:
|
||||
"""Comprehensive integration tests for DatasetService.get_process_rules method."""
|
||||
|
||||
def test_get_process_rules_with_existing_rule(self, db_session_with_containers):
|
||||
def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session):
|
||||
"""Test retrieval of process rules when rule exists."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
|
||||
rules_data = {
|
||||
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
|
||||
"segmentation": {"delimiter": "\n", "max_tokens": 500},
|
||||
}
|
||||
DatasetRetrievalTestDataFactory.create_process_rule(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
created_by=account.id,
|
||||
mode="custom",
|
||||
@ -554,11 +612,13 @@ class TestDatasetServiceGetProcessRules:
|
||||
assert result["mode"] == "custom"
|
||||
assert result["rules"] == rules_data
|
||||
|
||||
def test_get_process_rules_without_existing_rule(self, db_session_with_containers):
|
||||
def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session):
|
||||
"""Test retrieval of process rules when no rule exists (returns defaults)."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_process_rules(dataset.id)
|
||||
@ -572,16 +632,19 @@ class TestDatasetServiceGetProcessRules:
|
||||
class TestDatasetServiceGetDatasetQueries:
|
||||
"""Comprehensive integration tests for DatasetService.get_dataset_queries method."""
|
||||
|
||||
def test_get_dataset_queries_success(self, db_session_with_containers):
|
||||
def test_get_dataset_queries_success(self, db_session_with_containers: Session):
|
||||
"""Test successful retrieval of dataset queries."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
for i in range(3):
|
||||
DatasetRetrievalTestDataFactory.create_dataset_query(
|
||||
db_session_with_containers,
|
||||
dataset_id=dataset.id,
|
||||
created_by=account.id,
|
||||
content=f"query-{i}",
|
||||
@ -595,11 +658,13 @@ class TestDatasetServiceGetDatasetQueries:
|
||||
assert total == 3
|
||||
assert all(query.dataset_id == dataset.id for query in queries)
|
||||
|
||||
def test_get_dataset_queries_empty_result(self, db_session_with_containers):
|
||||
def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session):
|
||||
"""Test retrieval when no queries exist."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
page = 1
|
||||
per_page = 20
|
||||
|
||||
@ -614,14 +679,16 @@ class TestDatasetServiceGetDatasetQueries:
|
||||
class TestDatasetServiceGetRelatedApps:
|
||||
"""Comprehensive integration tests for DatasetService.get_related_apps method."""
|
||||
|
||||
def test_get_related_apps_success(self, db_session_with_containers):
|
||||
def test_get_related_apps_success(self, db_session_with_containers: Session):
|
||||
"""Test successful retrieval of related apps."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
|
||||
for _ in range(2):
|
||||
DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id)
|
||||
DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id)
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset.id)
|
||||
@ -630,11 +697,13 @@ class TestDatasetServiceGetRelatedApps:
|
||||
assert len(result) == 2
|
||||
assert all(join.dataset_id == dataset.id for join in result)
|
||||
|
||||
def test_get_related_apps_empty_result(self, db_session_with_containers):
|
||||
def test_get_related_apps_empty_result(self, db_session_with_containers: Session):
|
||||
"""Test retrieval when no related apps exist."""
|
||||
# Arrange
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
|
||||
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetRetrievalTestDataFactory.create_dataset(
|
||||
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
|
||||
)
|
||||
|
||||
# Act
|
||||
result = DatasetService.get_related_apps(dataset.id)
|
||||
|
||||
@ -2,9 +2,9 @@ from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
@ -15,7 +15,9 @@ class DatasetUpdateTestDataFactory:
|
||||
"""Factory class for creating real test data for dataset update integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
|
||||
def create_account_with_tenant(
|
||||
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER
|
||||
) -> tuple[Account, Tenant]:
|
||||
"""Create a real account and tenant with the given role."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
@ -23,12 +25,12 @@ class DatasetUpdateTestDataFactory:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{account.id}", status="normal")
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
@ -36,14 +38,15 @@ class DatasetUpdateTestDataFactory:
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
provider: str = "vendor",
|
||||
@ -71,12 +74,13 @@ class DatasetUpdateTestDataFactory:
|
||||
embedding_model=embedding_model,
|
||||
collection_binding_id=collection_binding_id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_external_binding(
|
||||
db_session_with_containers: Session,
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
created_by: str,
|
||||
@ -93,8 +97,8 @@ class DatasetUpdateTestDataFactory:
|
||||
external_knowledge_id=external_knowledge_id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
return binding
|
||||
|
||||
|
||||
@ -112,10 +116,11 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
# ==================== External Dataset Tests ====================
|
||||
|
||||
def test_update_external_dataset_success(self, db_session_with_containers):
|
||||
def test_update_external_dataset_success(self, db_session_with_containers: Session):
|
||||
"""Test successful update of external dataset."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
@ -124,12 +129,13 @@ class TestDatasetServiceUpdateDataset:
|
||||
retrieval_model="old_model",
|
||||
)
|
||||
binding = DatasetUpdateTestDataFactory.create_external_binding(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
binding_id = binding.id
|
||||
db.session.expunge(binding)
|
||||
db_session_with_containers.expunge(binding)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
@ -142,8 +148,8 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
db.session.refresh(dataset)
|
||||
updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
|
||||
db_session_with_containers.refresh(dataset)
|
||||
updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.description == "new_description"
|
||||
@ -153,15 +159,17 @@ class TestDatasetServiceUpdateDataset:
|
||||
assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
|
||||
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session):
|
||||
"""Test error when external knowledge id is missing."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
)
|
||||
DatasetUpdateTestDataFactory.create_external_binding(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=user.id,
|
||||
@ -173,17 +181,19 @@ class TestDatasetServiceUpdateDataset:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "External knowledge id is required" in str(context.value)
|
||||
db.session.rollback()
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers):
|
||||
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session):
|
||||
"""Test error when external knowledge api id is missing."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
)
|
||||
DatasetUpdateTestDataFactory.create_external_binding(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=user.id,
|
||||
@ -195,12 +205,13 @@ class TestDatasetServiceUpdateDataset:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "External knowledge api id is required" in str(context.value)
|
||||
db.session.rollback()
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers):
|
||||
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
@ -216,15 +227,16 @@ class TestDatasetServiceUpdateDataset:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "External knowledge binding not found" in str(context.value)
|
||||
db.session.rollback()
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
# ==================== Internal Dataset Basic Tests ====================
|
||||
|
||||
def test_update_internal_dataset_basic_success(self, db_session_with_containers):
|
||||
def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session):
|
||||
"""Test successful update of internal dataset with basic fields."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
@ -244,7 +256,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.description == "new_description"
|
||||
@ -254,11 +266,12 @@ class TestDatasetServiceUpdateDataset:
|
||||
assert dataset.embedding_model == "text-embedding-ada-002"
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers):
|
||||
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session):
|
||||
"""Test that None values are filtered out except for description field."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
@ -278,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.description is None
|
||||
@ -289,11 +302,12 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
# ==================== Indexing Technique Switch Tests ====================
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers):
|
||||
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session):
|
||||
"""Test updating internal dataset indexing technique to economy."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
@ -312,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
mock_task.delay.assert_called_once_with(dataset.id, "remove")
|
||||
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.indexing_technique == "economy"
|
||||
assert dataset.embedding_model is None
|
||||
assert dataset.embedding_model_provider is None
|
||||
@ -320,10 +334,11 @@ class TestDatasetServiceUpdateDataset:
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers):
|
||||
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session):
|
||||
"""Test updating internal dataset indexing technique to high_quality."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
@ -366,7 +381,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002")
|
||||
mock_task.delay.assert_called_once_with(dataset.id, "add")
|
||||
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.indexing_technique == "high_quality"
|
||||
assert dataset.embedding_model == "text-embedding-ada-002"
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
@ -380,9 +395,10 @@ class TestDatasetServiceUpdateDataset:
|
||||
self, db_session_with_containers
|
||||
):
|
||||
"""Test preserving embedding settings when indexing technique remains unchanged."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
@ -399,7 +415,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.indexing_technique == "high_quality"
|
||||
@ -409,11 +425,12 @@ class TestDatasetServiceUpdateDataset:
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers):
|
||||
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session):
|
||||
"""Test updating internal dataset with new embedding model."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
@ -465,7 +482,7 @@ class TestDatasetServiceUpdateDataset:
|
||||
regenerate_vectors_only=True,
|
||||
)
|
||||
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.embedding_model == "text-embedding-3-small"
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
assert dataset.collection_binding_id == binding.id
|
||||
@ -474,9 +491,9 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_update_dataset_not_found_error(self, db_session_with_containers):
|
||||
def test_update_dataset_not_found_error(self, db_session_with_containers: Session):
|
||||
"""Test error when dataset is not found."""
|
||||
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
@ -484,11 +501,16 @@ class TestDatasetServiceUpdateDataset:
|
||||
|
||||
assert "Dataset not found" in str(context.value)
|
||||
|
||||
def test_update_dataset_permission_error(self, db_session_with_containers):
|
||||
def test_update_dataset_permission_error(self, db_session_with_containers: Session):
|
||||
"""Test error when user doesn't have permission."""
|
||||
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.OWNER
|
||||
)
|
||||
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(
|
||||
db_session_with_containers, role=TenantAccountRole.NORMAL
|
||||
)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
provider="vendor",
|
||||
@ -500,10 +522,11 @@ class TestDatasetServiceUpdateDataset:
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset(dataset.id, update_data, outsider)
|
||||
|
||||
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers):
|
||||
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session):
|
||||
"""Test error when embedding model is not available."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import create_autospec, patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@ -19,7 +20,7 @@ class TestFileService:
|
||||
"""Integration tests for FileService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, db_session_with_containers):
|
||||
def engine(self, db_session_with_containers: Session):
|
||||
bind = db_session_with_containers.get_bind()
|
||||
assert isinstance(bind, Engine)
|
||||
return bind
|
||||
@ -46,7 +47,7 @@ class TestFileService:
|
||||
"extract_processor": mock_extract_processor,
|
||||
}
|
||||
|
||||
def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account for testing.
|
||||
|
||||
@ -67,18 +68,16 @@ class TestFileService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
@ -89,15 +88,15 @@ class TestFileService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
@ -118,14 +117,14 @@ class TestFileService:
|
||||
session_id=fake.uuid4(),
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account):
|
||||
def _create_test_upload_file(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, account
|
||||
):
|
||||
"""
|
||||
Helper method to create a test upload file for testing.
|
||||
|
||||
@ -155,15 +154,13 @@ class TestFileService:
|
||||
source_url="",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
# Test upload_file method
|
||||
def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful file upload with valid parameters.
|
||||
"""
|
||||
@ -196,7 +193,9 @@ class TestFileService:
|
||||
|
||||
assert upload_file.id is not None
|
||||
|
||||
def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_with_end_user(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with end user instead of account.
|
||||
"""
|
||||
@ -219,7 +218,7 @@ class TestFileService:
|
||||
assert upload_file.created_by_role == CreatorUserRole.END_USER
|
||||
|
||||
def test_upload_file_with_datasets_source(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with datasets source parameter.
|
||||
@ -244,7 +243,7 @@ class TestFileService:
|
||||
assert upload_file.source_url == "https://example.com/source"
|
||||
|
||||
def test_upload_file_invalid_filename_characters(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with invalid filename characters.
|
||||
@ -265,7 +264,7 @@ class TestFileService:
|
||||
)
|
||||
|
||||
def test_upload_file_filename_too_long(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with filename that exceeds length limit.
|
||||
@ -295,7 +294,7 @@ class TestFileService:
|
||||
assert len(base_name) <= 200
|
||||
|
||||
def test_upload_file_datasets_unsupported_type(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload for datasets with unsupported file type.
|
||||
@ -316,7 +315,9 @@ class TestFileService:
|
||||
source="datasets",
|
||||
)
|
||||
|
||||
def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_too_large(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with file size exceeding limit.
|
||||
"""
|
||||
@ -338,7 +339,7 @@ class TestFileService:
|
||||
|
||||
# Test is_file_size_within_limit method
|
||||
def test_is_file_size_within_limit_image_success(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for image files within limit.
|
||||
@ -351,7 +352,7 @@ class TestFileService:
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_video_success(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for video files within limit.
|
||||
@ -364,7 +365,7 @@ class TestFileService:
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_audio_success(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for audio files within limit.
|
||||
@ -377,7 +378,7 @@ class TestFileService:
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_document_success(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for document files within limit.
|
||||
@ -390,7 +391,7 @@ class TestFileService:
|
||||
assert result is True
|
||||
|
||||
def test_is_file_size_within_limit_image_exceeded(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for image files exceeding limit.
|
||||
@ -403,7 +404,7 @@ class TestFileService:
|
||||
assert result is False
|
||||
|
||||
def test_is_file_size_within_limit_unknown_extension(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size check for unknown file extension.
|
||||
@ -416,7 +417,7 @@ class TestFileService:
|
||||
assert result is True
|
||||
|
||||
# Test upload_text method
|
||||
def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful text upload.
|
||||
"""
|
||||
@ -447,7 +448,9 @@ class TestFileService:
|
||||
# Verify storage was called
|
||||
mock_external_service_dependencies["storage"].save.assert_called_once()
|
||||
|
||||
def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_text_name_too_long(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test text upload with name that exceeds length limit.
|
||||
"""
|
||||
@ -472,7 +475,9 @@ class TestFileService:
|
||||
assert upload_file.name == "a" * 200
|
||||
|
||||
# Test get_file_preview method
|
||||
def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_get_file_preview_success(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful file preview generation.
|
||||
"""
|
||||
@ -484,9 +489,8 @@ class TestFileService:
|
||||
|
||||
# Update file to have document extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = FileService(engine).get_file_preview(file_id=upload_file.id)
|
||||
|
||||
@ -494,7 +498,7 @@ class TestFileService:
|
||||
mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once()
|
||||
|
||||
def test_get_file_preview_file_not_found(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file preview with non-existent file.
|
||||
@ -506,7 +510,7 @@ class TestFileService:
|
||||
FileService(engine).get_file_preview(file_id=non_existent_id)
|
||||
|
||||
def test_get_file_preview_unsupported_file_type(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file preview with unsupported file type.
|
||||
@ -519,15 +523,14 @@ class TestFileService:
|
||||
|
||||
# Update file to have non-document extension
|
||||
upload_file.extension = "jpg"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
FileService(engine).get_file_preview(file_id=upload_file.id)
|
||||
|
||||
def test_get_file_preview_text_truncation(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file preview with text that exceeds preview limit.
|
||||
@ -540,9 +543,8 @@ class TestFileService:
|
||||
|
||||
# Update file to have document extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock long text content
|
||||
long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT
|
||||
@ -554,7 +556,9 @@ class TestFileService:
|
||||
assert result == "x" * 3000
|
||||
|
||||
# Test get_image_preview method
|
||||
def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_get_image_preview_success(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful image preview generation.
|
||||
"""
|
||||
@ -566,9 +570,8 @@ class TestFileService:
|
||||
|
||||
# Update file to have image extension
|
||||
upload_file.extension = "jpg"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
@ -586,7 +589,7 @@ class TestFileService:
|
||||
mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once()
|
||||
|
||||
def test_get_image_preview_invalid_signature(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test image preview with invalid signature.
|
||||
@ -613,7 +616,7 @@ class TestFileService:
|
||||
)
|
||||
|
||||
def test_get_image_preview_file_not_found(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test image preview with non-existent file.
|
||||
@ -634,7 +637,7 @@ class TestFileService:
|
||||
)
|
||||
|
||||
def test_get_image_preview_unsupported_file_type(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test image preview with non-image file type.
|
||||
@ -647,9 +650,8 @@ class TestFileService:
|
||||
|
||||
# Update file to have non-image extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
timestamp = "1234567890"
|
||||
nonce = "test_nonce"
|
||||
@ -665,7 +667,7 @@ class TestFileService:
|
||||
|
||||
# Test get_file_generator_by_file_id method
|
||||
def test_get_file_generator_by_file_id_success(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful file generator retrieval.
|
||||
@ -692,7 +694,7 @@ class TestFileService:
|
||||
mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once()
|
||||
|
||||
def test_get_file_generator_by_file_id_invalid_signature(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file generator retrieval with invalid signature.
|
||||
@ -719,7 +721,7 @@ class TestFileService:
|
||||
)
|
||||
|
||||
def test_get_file_generator_by_file_id_file_not_found(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file generator retrieval with non-existent file.
|
||||
@ -741,7 +743,7 @@ class TestFileService:
|
||||
|
||||
# Test get_public_image_preview method
|
||||
def test_get_public_image_preview_success(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful public image preview generation.
|
||||
@ -754,9 +756,8 @@ class TestFileService:
|
||||
|
||||
# Update file to have image extension
|
||||
upload_file.extension = "jpg"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id)
|
||||
|
||||
@ -765,7 +766,7 @@ class TestFileService:
|
||||
mock_external_service_dependencies["storage"].load.assert_called_once()
|
||||
|
||||
def test_get_public_image_preview_file_not_found(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test public image preview with non-existent file.
|
||||
@ -777,7 +778,7 @@ class TestFileService:
|
||||
FileService(engine).get_public_image_preview(file_id=non_existent_id)
|
||||
|
||||
def test_get_public_image_preview_unsupported_file_type(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test public image preview with non-image file type.
|
||||
@ -790,15 +791,16 @@ class TestFileService:
|
||||
|
||||
# Update file to have non-image extension
|
||||
upload_file.extension = "pdf"
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
FileService(engine).get_public_image_preview(file_id=upload_file.id)
|
||||
|
||||
# Test edge cases and boundary conditions
|
||||
def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_empty_content(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with empty content.
|
||||
"""
|
||||
@ -820,7 +822,7 @@ class TestFileService:
|
||||
assert upload_file.size == 0
|
||||
|
||||
def test_upload_file_special_characters_in_name(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with special characters in filename (but valid ones).
|
||||
@ -843,7 +845,7 @@ class TestFileService:
|
||||
assert upload_file.name == filename
|
||||
|
||||
def test_upload_file_different_case_extensions(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with different case extensions.
|
||||
@ -865,7 +867,9 @@ class TestFileService:
|
||||
assert upload_file is not None
|
||||
assert upload_file.extension == "pdf" # Should be converted to lowercase
|
||||
|
||||
def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_text_empty_text(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test text upload with empty text.
|
||||
"""
|
||||
@ -888,7 +892,9 @@ class TestFileService:
|
||||
assert upload_file is not None
|
||||
assert upload_file.size == 0
|
||||
|
||||
def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_file_size_limits_edge_cases(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file size limits with edge case values.
|
||||
"""
|
||||
@ -908,7 +914,9 @@ class TestFileService:
|
||||
result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size)
|
||||
assert result is False
|
||||
|
||||
def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_with_source_url(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with source URL that gets overridden by signed URL.
|
||||
"""
|
||||
@ -946,7 +954,7 @@ class TestFileService:
|
||||
|
||||
# Test file extension blacklist
|
||||
def test_upload_file_blocked_extension(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with blocked extension.
|
||||
@ -969,7 +977,7 @@ class TestFileService:
|
||||
)
|
||||
|
||||
def test_upload_file_blocked_extension_case_insensitive(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with blocked extension (case insensitive).
|
||||
@ -992,7 +1000,9 @@ class TestFileService:
|
||||
user=account,
|
||||
)
|
||||
|
||||
def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_not_in_blacklist(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with extension not in blacklist.
|
||||
"""
|
||||
@ -1016,7 +1026,9 @@ class TestFileService:
|
||||
assert upload_file.name == filename
|
||||
assert upload_file.extension == "pdf"
|
||||
|
||||
def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
|
||||
def test_upload_file_empty_blacklist(
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with empty blacklist (default behavior).
|
||||
"""
|
||||
@ -1041,7 +1053,7 @@ class TestFileService:
|
||||
assert upload_file.extension == "sh"
|
||||
|
||||
def test_upload_file_multiple_blocked_extensions(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with multiple blocked extensions.
|
||||
@ -1066,7 +1078,7 @@ class TestFileService:
|
||||
)
|
||||
|
||||
def test_upload_file_no_extension_with_blacklist(
|
||||
self, db_session_with_containers, engine, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test file upload with no extension when blacklist is configured.
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import MessageFeedback
|
||||
from services.app_service import AppService
|
||||
@ -69,7 +70,7 @@ class TestMessageService:
|
||||
# "current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -127,11 +128,10 @@ class TestMessageService:
|
||||
# mock_external_service_dependencies["current_user"].id = account_id
|
||||
# mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
|
||||
|
||||
def _create_test_conversation(self, app, account, fake):
|
||||
def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake):
|
||||
"""
|
||||
Helper method to create a test conversation with all required fields.
|
||||
"""
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
|
||||
conversation = Conversation(
|
||||
@ -153,17 +153,16 @@ class TestMessageService:
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.flush()
|
||||
return conversation
|
||||
|
||||
def _create_test_message(self, app, conversation, account, fake):
|
||||
def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake):
|
||||
"""
|
||||
Helper method to create a test message with all required fields.
|
||||
"""
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
message = Message(
|
||||
@ -192,11 +191,13 @@ class TestMessageService:
|
||||
from_account_id=account.id,
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
return message
|
||||
|
||||
def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_first_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by first ID.
|
||||
"""
|
||||
@ -204,10 +205,10 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination by first ID
|
||||
@ -228,7 +229,9 @@ class TestMessageService:
|
||||
# Verify messages are in ascending order
|
||||
assert result.data[0].created_at <= result.data[1].created_at
|
||||
|
||||
def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_first_id_no_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID when no user is provided.
|
||||
"""
|
||||
@ -246,7 +249,7 @@ class TestMessageService:
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_first_id_no_conversation_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID when no conversation ID is provided.
|
||||
@ -265,7 +268,7 @@ class TestMessageService:
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_first_id_invalid_first_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by first ID with invalid first_id.
|
||||
@ -274,8 +277,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Test pagination with invalid first_id
|
||||
with pytest.raises(FirstMessageNotExistsError):
|
||||
@ -287,7 +290,9 @@ class TestMessageService:
|
||||
limit=10,
|
||||
)
|
||||
|
||||
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_last_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID.
|
||||
"""
|
||||
@ -295,10 +300,10 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination by last ID
|
||||
@ -319,7 +324,7 @@ class TestMessageService:
|
||||
assert result.data[0].created_at >= result.data[1].created_at
|
||||
|
||||
def test_pagination_by_last_id_with_include_ids(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with include_ids filter.
|
||||
@ -328,10 +333,10 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and multiple messages
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
messages = []
|
||||
for i in range(5):
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
messages.append(message)
|
||||
|
||||
# Test pagination with include_ids
|
||||
@ -347,7 +352,9 @@ class TestMessageService:
|
||||
for message in result.data:
|
||||
assert message.id in include_ids
|
||||
|
||||
def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_last_id_no_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID when no user is provided.
|
||||
"""
|
||||
@ -363,7 +370,7 @@ class TestMessageService:
|
||||
assert result.has_more is False
|
||||
|
||||
def test_pagination_by_last_id_invalid_last_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with invalid last_id.
|
||||
@ -372,8 +379,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Test pagination with invalid last_id
|
||||
with pytest.raises(LastMessageNotExistsError):
|
||||
@ -385,7 +392,7 @@ class TestMessageService:
|
||||
conversation_id=conversation.id,
|
||||
)
|
||||
|
||||
def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful creation of feedback.
|
||||
"""
|
||||
@ -393,8 +400,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create feedback
|
||||
rating = "like"
|
||||
@ -413,7 +420,7 @@ class TestMessageService:
|
||||
assert feedback.from_account_id == account.id
|
||||
assert feedback.from_end_user_id is None
|
||||
|
||||
def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test creating feedback when no user is provided.
|
||||
"""
|
||||
@ -421,8 +428,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Test creating feedback with no user
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
@ -430,7 +437,9 @@ class TestMessageService:
|
||||
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
|
||||
)
|
||||
|
||||
def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_feedback_update_existing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test updating existing feedback.
|
||||
"""
|
||||
@ -438,8 +447,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
initial_rating = "like"
|
||||
@ -462,7 +471,9 @@ class TestMessageService:
|
||||
assert updated_feedback.rating != initial_rating
|
||||
assert updated_feedback.content != initial_content
|
||||
|
||||
def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_feedback_delete_existing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test deleting existing feedback by setting rating to None.
|
||||
"""
|
||||
@ -470,8 +481,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
feedback = MessageService.create_feedback(
|
||||
@ -482,13 +493,14 @@ class TestMessageService:
|
||||
MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)
|
||||
|
||||
# Verify feedback was deleted
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
|
||||
deleted_feedback = (
|
||||
db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
|
||||
)
|
||||
assert deleted_feedback is None
|
||||
|
||||
def test_create_feedback_no_rating_when_not_exists(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creating feedback with no rating when feedback doesn't exist.
|
||||
@ -497,8 +509,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Test creating feedback with no rating when no feedback exists
|
||||
with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
|
||||
@ -506,7 +518,9 @@ class TestMessageService:
|
||||
app_model=app, message_id=message.id, user=account, rating=None, content=None
|
||||
)
|
||||
|
||||
def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_all_messages_feedbacks_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of all message feedbacks.
|
||||
"""
|
||||
@ -516,8 +530,8 @@ class TestMessageService:
|
||||
# Create multiple conversations and messages with feedbacks
|
||||
feedbacks = []
|
||||
for i in range(3):
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app,
|
||||
@ -539,7 +553,7 @@ class TestMessageService:
|
||||
assert result[i]["created_at"] >= result[i + 1]["created_at"]
|
||||
|
||||
def test_get_all_messages_feedbacks_pagination(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of message feedbacks.
|
||||
@ -549,8 +563,8 @@ class TestMessageService:
|
||||
|
||||
# Create multiple conversations and messages with feedbacks
|
||||
for i in range(5):
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
|
||||
@ -569,7 +583,7 @@ class TestMessageService:
|
||||
page_2_ids = {feedback["id"] for feedback in result_page_2}
|
||||
assert len(page_1_ids.intersection(page_2_ids)) == 0
|
||||
|
||||
def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of message.
|
||||
"""
|
||||
@ -577,8 +591,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Get message
|
||||
retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
|
||||
@ -590,7 +604,7 @@ class TestMessageService:
|
||||
assert retrieved_message.from_source == "console"
|
||||
assert retrieved_message.from_account_id == account.id
|
||||
|
||||
def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting message that doesn't exist.
|
||||
"""
|
||||
@ -601,7 +615,7 @@ class TestMessageService:
|
||||
with pytest.raises(MessageNotExistsError):
|
||||
MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
|
||||
|
||||
def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting message with wrong user (different account).
|
||||
"""
|
||||
@ -609,8 +623,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create another account
|
||||
from services.account_service import AccountService, TenantService
|
||||
@ -628,7 +642,7 @@ class TestMessageService:
|
||||
MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
|
||||
|
||||
def test_get_suggested_questions_after_answer_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful generation of suggested questions after answer.
|
||||
@ -637,8 +651,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Mock the LLMGenerator to return specific questions
|
||||
mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
|
||||
@ -665,7 +679,7 @@ class TestMessageService:
|
||||
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
|
||||
|
||||
def test_get_suggested_questions_after_answer_no_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when no user is provided.
|
||||
@ -674,8 +688,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Test getting suggested questions with no user
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -686,7 +700,7 @@ class TestMessageService:
|
||||
)
|
||||
|
||||
def test_get_suggested_questions_after_answer_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when feature is disabled.
|
||||
@ -695,8 +709,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Mock the feature to be disabled
|
||||
mock_external_service_dependencies[
|
||||
@ -712,7 +726,7 @@ class TestMessageService:
|
||||
)
|
||||
|
||||
def test_get_suggested_questions_after_answer_no_workflow(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions when no workflow exists.
|
||||
@ -721,8 +735,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Mock no workflow
|
||||
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
|
||||
@ -738,7 +752,7 @@ class TestMessageService:
|
||||
assert result == []
|
||||
|
||||
def test_get_suggested_questions_after_answer_debugger_mode(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting suggested questions in debugger mode.
|
||||
@ -747,8 +761,8 @@ class TestMessageService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
# Create a conversation and message
|
||||
conversation = self._create_test_conversation(app, account, fake)
|
||||
message = self._create_test_message(app, conversation, account, fake)
|
||||
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Mock questions
|
||||
mock_questions = ["Debug question 1", "Debug question 2"]
|
||||
|
||||
@ -6,9 +6,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import (
|
||||
@ -40,25 +40,25 @@ class TestMessagesCleanServiceIntegration:
|
||||
PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
"""Clean up database before and after each test to ensure isolation."""
|
||||
yield
|
||||
# Clear all test data in correct order (respecting foreign key constraints)
|
||||
db.session.query(DatasetRetrieverResource).delete()
|
||||
db.session.query(AppAnnotationHitHistory).delete()
|
||||
db.session.query(SavedMessage).delete()
|
||||
db.session.query(MessageFile).delete()
|
||||
db.session.query(MessageAgentThought).delete()
|
||||
db.session.query(MessageChain).delete()
|
||||
db.session.query(MessageAnnotation).delete()
|
||||
db.session.query(MessageFeedback).delete()
|
||||
db.session.query(Message).delete()
|
||||
db.session.query(Conversation).delete()
|
||||
db.session.query(App).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.commit()
|
||||
db_session_with_containers.query(DatasetRetrieverResource).delete()
|
||||
db_session_with_containers.query(AppAnnotationHitHistory).delete()
|
||||
db_session_with_containers.query(SavedMessage).delete()
|
||||
db_session_with_containers.query(MessageFile).delete()
|
||||
db_session_with_containers.query(MessageAgentThought).delete()
|
||||
db_session_with_containers.query(MessageChain).delete()
|
||||
db_session_with_containers.query(MessageAnnotation).delete()
|
||||
db_session_with_containers.query(MessageFeedback).delete()
|
||||
db_session_with_containers.query(Message).delete()
|
||||
db_session_with_containers.query(Conversation).delete()
|
||||
db_session_with_containers.query(App).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_redis(self):
|
||||
@ -100,7 +100,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False):
|
||||
yield
|
||||
|
||||
def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX):
|
||||
def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX):
|
||||
"""Helper to create account and tenant."""
|
||||
fake = Faker()
|
||||
|
||||
@ -110,28 +110,28 @@ class TestMessagesCleanServiceIntegration:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
plan=str(plan),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
tenant_account_join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=TenantAccountRole.OWNER,
|
||||
)
|
||||
db.session.add(tenant_account_join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant_account_join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_app(self, tenant, account):
|
||||
def _create_app(self, db_session_with_containers: Session, tenant, account):
|
||||
"""Helper to create an app."""
|
||||
fake = Faker()
|
||||
|
||||
@ -149,12 +149,12 @@ class TestMessagesCleanServiceIntegration:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app
|
||||
|
||||
def _create_conversation(self, app):
|
||||
def _create_conversation(self, db_session_with_containers: Session, app):
|
||||
"""Helper to create a conversation."""
|
||||
conversation = Conversation(
|
||||
app_id=app.id,
|
||||
@ -168,12 +168,14 @@ class TestMessagesCleanServiceIntegration:
|
||||
from_source="api",
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return conversation
|
||||
|
||||
def _create_message(self, app, conversation, created_at=None, with_relations=True):
|
||||
def _create_message(
|
||||
self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True
|
||||
):
|
||||
"""Helper to create a message with optional related records."""
|
||||
if created_at is None:
|
||||
created_at = datetime.datetime.now()
|
||||
@ -197,16 +199,16 @@ class TestMessagesCleanServiceIntegration:
|
||||
from_account_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
db.session.add(message)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
if with_relations:
|
||||
self._create_message_relations(message)
|
||||
self._create_message_relations(db_session_with_containers, message)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
return message
|
||||
|
||||
def _create_message_relations(self, message):
|
||||
def _create_message_relations(self, db_session_with_containers: Session, message):
|
||||
"""Helper to create all message-related records."""
|
||||
# MessageFeedback
|
||||
feedback = MessageFeedback(
|
||||
@ -217,7 +219,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
from_source="api",
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(feedback)
|
||||
db_session_with_containers.add(feedback)
|
||||
|
||||
# MessageAnnotation
|
||||
annotation = MessageAnnotation(
|
||||
@ -228,7 +230,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
content="Test annotation",
|
||||
account_id=message.from_account_id,
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db_session_with_containers.add(annotation)
|
||||
|
||||
# MessageChain
|
||||
chain = MessageChain(
|
||||
@ -237,8 +239,8 @@ class TestMessagesCleanServiceIntegration:
|
||||
input=json.dumps({"test": "input"}),
|
||||
output=json.dumps({"test": "output"}),
|
||||
)
|
||||
db.session.add(chain)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(chain)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# MessageFile
|
||||
file = MessageFile(
|
||||
@ -250,7 +252,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
created_by_role="end_user",
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(file)
|
||||
db_session_with_containers.add(file)
|
||||
|
||||
# SavedMessage
|
||||
saved = SavedMessage(
|
||||
@ -259,9 +261,9 @@ class TestMessagesCleanServiceIntegration:
|
||||
created_by_role="end_user",
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
db.session.add(saved)
|
||||
db_session_with_containers.add(saved)
|
||||
|
||||
db.session.flush()
|
||||
db_session_with_containers.flush()
|
||||
|
||||
# AppAnnotationHitHistory
|
||||
hit = AppAnnotationHitHistory(
|
||||
@ -275,7 +277,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
annotation_question="Test annotation question",
|
||||
annotation_content="Test annotation content",
|
||||
)
|
||||
db.session.add(hit)
|
||||
db_session_with_containers.add(hit)
|
||||
|
||||
# DatasetRetrieverResource
|
||||
resource = DatasetRetrieverResource(
|
||||
@ -296,25 +298,29 @@ class TestMessagesCleanServiceIntegration:
|
||||
retriever_from="dataset",
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
db.session.add(resource)
|
||||
db_session_with_containers.add(resource)
|
||||
|
||||
def test_billing_disabled_deletes_all_messages_in_time_range(
|
||||
self, db_session_with_containers, mock_billing_disabled
|
||||
self, db_session_with_containers: Session, mock_billing_disabled
|
||||
):
|
||||
"""Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan."""
|
||||
# Arrange - Create tenant with messages (plan doesn't matter for billing disabled)
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create messages: in-range (should be deleted) and out-of-range (should be kept)
|
||||
in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0)
|
||||
out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0)
|
||||
|
||||
in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True)
|
||||
in_range_msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True
|
||||
)
|
||||
in_range_msg_id = in_range_msg.id
|
||||
|
||||
out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True)
|
||||
out_of_range_msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True
|
||||
)
|
||||
out_of_range_msg_id = out_of_range_msg.id
|
||||
|
||||
# Act - create_message_clean_policy should return BillingDisabledPolicy
|
||||
@ -336,17 +342,34 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
# In-range message deleted
|
||||
assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0
|
||||
# Out-of-range message kept
|
||||
assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1
|
||||
|
||||
# Related records of in-range message deleted
|
||||
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0
|
||||
assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0
|
||||
assert (
|
||||
db_session_with_containers.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == in_range_msg_id)
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
db_session_with_containers.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.message_id == in_range_msg_id)
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
# Related records of out-of-range message kept
|
||||
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1
|
||||
assert (
|
||||
db_session_with_containers.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id == out_of_range_msg_id)
|
||||
.count()
|
||||
== 1
|
||||
)
|
||||
|
||||
def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_no_messages_returns_empty_stats(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test cleaning when there are no messages to delete (B1)."""
|
||||
# Arrange
|
||||
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||
@ -371,36 +394,42 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["filtered_messages"] == 0
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_mixed_sandbox_and_paid_tenants(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test cleaning with mixed sandbox and paid tenants (B2)."""
|
||||
# Arrange - Create sandbox tenants with expired messages
|
||||
sandbox_tenants = []
|
||||
sandbox_message_ids = []
|
||||
for i in range(2):
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
sandbox_tenants.append(tenant)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create 3 expired messages per sandbox tenant
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
for j in range(3):
|
||||
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
|
||||
msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j)
|
||||
)
|
||||
sandbox_message_ids.append(msg.id)
|
||||
|
||||
# Create paid tenants with expired messages (should NOT be deleted)
|
||||
paid_tenants = []
|
||||
paid_message_ids = []
|
||||
for i in range(2):
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL)
|
||||
paid_tenants.append(tenant)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create 2 expired messages per paid tenant
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
for j in range(2):
|
||||
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
|
||||
msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j)
|
||||
)
|
||||
paid_message_ids.append(msg.id)
|
||||
|
||||
# Mock billing service - return plan and expiration_date
|
||||
@ -442,29 +471,39 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 6
|
||||
|
||||
# Only sandbox messages should be deleted
|
||||
assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
|
||||
# Paid messages should remain
|
||||
assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
|
||||
|
||||
# Related records of sandbox messages should be deleted
|
||||
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
|
||||
assert (
|
||||
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count()
|
||||
db_session_with_containers.query(MessageFeedback)
|
||||
.where(MessageFeedback.message_id.in_(sandbox_message_ids))
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
db_session_with_containers.query(MessageAnnotation)
|
||||
.where(MessageAnnotation.message_id.in_(sandbox_message_ids))
|
||||
.count()
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_cursor_pagination_multiple_batches(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test cursor pagination works correctly across multiple batches (B3)."""
|
||||
# Arrange - Create sandbox tenant with messages that will span multiple batches
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create 10 expired messages with different timestamps
|
||||
base_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
message_ids = []
|
||||
for i in range(10):
|
||||
msg = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=base_date + datetime.timedelta(hours=i),
|
||||
@ -498,20 +537,22 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 10
|
||||
|
||||
# All messages should be deleted
|
||||
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0
|
||||
|
||||
def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_dry_run_does_not_delete(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
|
||||
"""Test dry_run mode does not delete messages (B4)."""
|
||||
# Arrange
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create expired messages
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
message_ids = []
|
||||
for i in range(3):
|
||||
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
|
||||
msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i)
|
||||
)
|
||||
message_ids.append(msg.id)
|
||||
|
||||
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
|
||||
@ -540,21 +581,26 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 0 # But NOT deleted
|
||||
|
||||
# All messages should still exist
|
||||
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3
|
||||
# Related records should also still exist
|
||||
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3
|
||||
assert (
|
||||
db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count()
|
||||
== 3
|
||||
)
|
||||
|
||||
def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_partial_plan_data_safe_default(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test when billing returns partial data, unknown tenants are preserved (B5)."""
|
||||
# Arrange - Create 3 tenants
|
||||
tenants_data = []
|
||||
for i in range(3):
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
msg = self._create_message(app, conv, created_at=expired_date)
|
||||
msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date)
|
||||
|
||||
tenants_data.append(
|
||||
{
|
||||
@ -600,28 +646,30 @@ class TestMessagesCleanServiceIntegration:
|
||||
|
||||
# Check which messages were deleted
|
||||
assert (
|
||||
db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
|
||||
db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
|
||||
) # Sandbox tenant's message deleted
|
||||
|
||||
assert (
|
||||
db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
|
||||
db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
|
||||
) # Professional tenant's message preserved
|
||||
|
||||
assert (
|
||||
db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
|
||||
db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
|
||||
) # Unknown tenant's message preserved (safe default)
|
||||
|
||||
def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_empty_plan_data_skips_deletion(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test when billing returns empty data, skip deletion entirely (B6)."""
|
||||
# Arrange
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
msg = self._create_message(app, conv, created_at=expired_date)
|
||||
msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date)
|
||||
msg_id = msg.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock billing service to return empty data (simulating failure/no data scenario)
|
||||
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
|
||||
@ -644,17 +692,20 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
# Message should still exist (safe default - don't delete if plan is unknown)
|
||||
assert db.session.query(Message).where(Message.id == msg_id).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1
|
||||
|
||||
def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_time_range_boundary_behavior(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test that messages are correctly filtered by [start_from, end_before) time range (B7)."""
|
||||
# Arrange
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create messages: before range, in range, after range
|
||||
msg_before = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from
|
||||
@ -663,6 +714,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
msg_before_id = msg_before.id
|
||||
|
||||
msg_at_start = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive)
|
||||
@ -671,6 +723,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
msg_at_start_id = msg_at_start.id
|
||||
|
||||
msg_in_range = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range
|
||||
@ -679,6 +732,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
msg_in_range_id = msg_in_range.id
|
||||
|
||||
msg_at_end = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive)
|
||||
@ -687,6 +741,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
msg_at_end_id = msg_at_end.id
|
||||
|
||||
msg_after = self._create_message(
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before
|
||||
@ -694,7 +749,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
)
|
||||
msg_after_id = msg_after.id
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock billing service
|
||||
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
|
||||
@ -722,17 +777,17 @@ class TestMessagesCleanServiceIntegration:
|
||||
|
||||
# Verify specific messages using stored IDs
|
||||
# Before range, kept
|
||||
assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1
|
||||
# At start (inclusive), deleted
|
||||
assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0
|
||||
# In range, deleted
|
||||
assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0
|
||||
# At end (exclusive), kept
|
||||
assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1
|
||||
# After range, kept
|
||||
assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1
|
||||
|
||||
def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
|
||||
"""Test cleaning with different graceful period scenarios (B8)."""
|
||||
# Arrange - Create 5 different tenants with different plan and expiration scenarios
|
||||
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
|
||||
@ -740,50 +795,60 @@ class TestMessagesCleanServiceIntegration:
|
||||
|
||||
# Scenario 1: Sandbox plan with expiration within graceful period (5 days ago)
|
||||
# Should NOT be deleted
|
||||
account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app1 = self._create_app(tenant1, account1)
|
||||
conv1 = self._create_conversation(app1)
|
||||
account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app1 = self._create_app(db_session_with_containers, tenant1, account1)
|
||||
conv1 = self._create_conversation(db_session_with_containers, app1)
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
|
||||
msg1 = self._create_message(
|
||||
db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False
|
||||
)
|
||||
msg1_id = msg1.id
|
||||
expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period
|
||||
|
||||
# Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago)
|
||||
# Should be deleted
|
||||
account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app2 = self._create_app(tenant2, account2)
|
||||
conv2 = self._create_conversation(app2)
|
||||
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
|
||||
account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app2 = self._create_app(db_session_with_containers, tenant2, account2)
|
||||
conv2 = self._create_conversation(db_session_with_containers, app2)
|
||||
msg2 = self._create_message(
|
||||
db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False
|
||||
)
|
||||
msg2_id = msg2.id
|
||||
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period
|
||||
|
||||
# Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
|
||||
# Should be deleted
|
||||
account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app3 = self._create_app(tenant3, account3)
|
||||
conv3 = self._create_conversation(app3)
|
||||
msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False)
|
||||
account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app3 = self._create_app(db_session_with_containers, tenant3, account3)
|
||||
conv3 = self._create_conversation(db_session_with_containers, app3)
|
||||
msg3 = self._create_message(
|
||||
db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False
|
||||
)
|
||||
msg3_id = msg3.id
|
||||
|
||||
# Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
|
||||
# Should NOT be deleted
|
||||
account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL)
|
||||
app4 = self._create_app(tenant4, account4)
|
||||
conv4 = self._create_conversation(app4)
|
||||
msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False)
|
||||
account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL)
|
||||
app4 = self._create_app(db_session_with_containers, tenant4, account4)
|
||||
conv4 = self._create_conversation(db_session_with_containers, app4)
|
||||
msg4 = self._create_message(
|
||||
db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False
|
||||
)
|
||||
msg4_id = msg4.id
|
||||
future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year
|
||||
|
||||
# Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago)
|
||||
# Should NOT be deleted (boundary is exclusive: > graceful_period)
|
||||
account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app5 = self._create_app(tenant5, account5)
|
||||
conv5 = self._create_conversation(app5)
|
||||
msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False)
|
||||
account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app5 = self._create_app(db_session_with_containers, tenant5, account5)
|
||||
conv5 = self._create_conversation(db_session_with_containers, app5)
|
||||
msg5 = self._create_message(
|
||||
db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False
|
||||
)
|
||||
msg5_id = msg5.id
|
||||
expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock billing service with all scenarios
|
||||
plan_map = {
|
||||
@ -832,23 +897,31 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 2
|
||||
|
||||
# Verify each scenario using saved IDs
|
||||
assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
|
||||
assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted
|
||||
assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted
|
||||
assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept
|
||||
assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
|
||||
assert (
|
||||
db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0
|
||||
) # Beyond grace, deleted
|
||||
assert (
|
||||
db_session_with_containers.query(Message).where(Message.id == msg3_id).count() == 0
|
||||
) # No subscription, deleted
|
||||
assert (
|
||||
db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1
|
||||
) # Professional plan, kept
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
|
||||
|
||||
def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
|
||||
"""Test that whitelisted tenants' messages are not deleted (B9)."""
|
||||
# Arrange - Create 3 sandbox tenants with expired messages
|
||||
tenants_data = []
|
||||
for i in range(3):
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
msg = self._create_message(app, conv, created_at=expired_date, with_relations=False)
|
||||
msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=expired_date, with_relations=False
|
||||
)
|
||||
|
||||
tenants_data.append(
|
||||
{
|
||||
@ -897,27 +970,33 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 1
|
||||
|
||||
# Verify tenant0's message still exists (whitelisted)
|
||||
assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
|
||||
|
||||
# Verify tenant1's message still exists (whitelisted)
|
||||
assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
|
||||
assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
|
||||
|
||||
# Verify tenant2's message was deleted (not whitelisted)
|
||||
assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
|
||||
|
||||
def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
|
||||
def test_from_days_cleans_old_messages(
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test from_days correctly cleans messages older than N days (B11)."""
|
||||
# Arrange
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
# Create old messages (should be deleted - older than 30 days)
|
||||
old_date = datetime.datetime.now() - datetime.timedelta(days=45)
|
||||
old_msg_ids = []
|
||||
for i in range(3):
|
||||
msg = self._create_message(
|
||||
app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=old_date - datetime.timedelta(hours=i),
|
||||
with_relations=False,
|
||||
)
|
||||
old_msg_ids.append(msg.id)
|
||||
|
||||
@ -926,11 +1005,15 @@ class TestMessagesCleanServiceIntegration:
|
||||
recent_msg_ids = []
|
||||
for i in range(2):
|
||||
msg = self._create_message(
|
||||
app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False
|
||||
db_session_with_containers,
|
||||
app,
|
||||
conv,
|
||||
created_at=recent_date - datetime.timedelta(hours=i),
|
||||
with_relations=False,
|
||||
)
|
||||
recent_msg_ids.append(msg.id)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
|
||||
mock_billing.return_value = {
|
||||
@ -955,30 +1038,34 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 3
|
||||
|
||||
# Old messages deleted
|
||||
assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0
|
||||
# Recent messages kept
|
||||
assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2
|
||||
|
||||
def test_whitelist_precedence_over_grace_period(
|
||||
self, db_session_with_containers, mock_billing_enabled, mock_whitelist
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test that whitelist takes precedence over grace period logic."""
|
||||
# Arrange - Create 2 sandbox tenants
|
||||
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
|
||||
|
||||
# Tenant1: whitelisted, expired beyond grace period
|
||||
account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app1 = self._create_app(tenant1, account1)
|
||||
conv1 = self._create_conversation(app1)
|
||||
account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app1 = self._create_app(db_session_with_containers, tenant1, account1)
|
||||
conv1 = self._create_conversation(db_session_with_containers, app1)
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
|
||||
msg1 = self._create_message(
|
||||
db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False
|
||||
)
|
||||
expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace
|
||||
|
||||
# Tenant2: not whitelisted, within grace period
|
||||
account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app2 = self._create_app(tenant2, account2)
|
||||
conv2 = self._create_conversation(app2)
|
||||
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
|
||||
account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app2 = self._create_app(db_session_with_containers, tenant2, account2)
|
||||
conv2 = self._create_conversation(db_session_with_containers, app2)
|
||||
msg2 = self._create_message(
|
||||
db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False
|
||||
)
|
||||
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace
|
||||
|
||||
# Mock billing service
|
||||
@ -1019,22 +1106,26 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 0
|
||||
|
||||
# Verify both messages still exist
|
||||
assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
|
||||
assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period
|
||||
assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
|
||||
assert (
|
||||
db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1
|
||||
) # Within grace period
|
||||
|
||||
def test_empty_whitelist_deletes_eligible_messages(
|
||||
self, db_session_with_containers, mock_billing_enabled, mock_whitelist
|
||||
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
|
||||
):
|
||||
"""Test that empty whitelist behaves as no whitelist (all eligible messages deleted)."""
|
||||
# Arrange - Create sandbox tenant with expired messages
|
||||
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(tenant, account)
|
||||
conv = self._create_conversation(app)
|
||||
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
|
||||
app = self._create_app(db_session_with_containers, tenant, account)
|
||||
conv = self._create_conversation(db_session_with_containers, app)
|
||||
|
||||
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
|
||||
msg_ids = []
|
||||
for i in range(3):
|
||||
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
|
||||
msg = self._create_message(
|
||||
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i)
|
||||
)
|
||||
msg_ids.append(msg.id)
|
||||
|
||||
# Mock billing service
|
||||
@ -1068,4 +1159,4 @@ class TestMessagesCleanServiceIntegration:
|
||||
assert stats["total_deleted"] == 3
|
||||
|
||||
# Verify all messages were deleted
|
||||
assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0
|
||||
assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -32,7 +33,7 @@ class TestMetadataService:
|
||||
"document_service": mock_document_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -53,18 +54,16 @@ class TestMetadataService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -73,15 +72,17 @@ class TestMetadataService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant):
|
||||
def _create_test_dataset(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant
|
||||
):
|
||||
"""
|
||||
Helper method to create a test dataset for testing.
|
||||
|
||||
@ -105,14 +106,14 @@ class TestMetadataService:
|
||||
built_in_field_enabled=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account):
|
||||
def _create_test_document(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account
|
||||
):
|
||||
"""
|
||||
Helper method to create a test document for testing.
|
||||
|
||||
@ -141,14 +142,12 @@ class TestMetadataService:
|
||||
doc_language="en",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful metadata creation with valid parameters.
|
||||
"""
|
||||
@ -178,13 +177,14 @@ class TestMetadataService:
|
||||
assert result.created_by == account.id
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.created_at is not None
|
||||
|
||||
def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_metadata_name_too_long(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata creation fails when name exceeds 255 characters.
|
||||
"""
|
||||
@ -207,7 +207,9 @@ class TestMetadataService:
|
||||
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
|
||||
MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_metadata_name_already_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata creation fails when name already exists in the same dataset.
|
||||
"""
|
||||
@ -235,7 +237,7 @@ class TestMetadataService:
|
||||
MetadataService.create_metadata(dataset.id, second_metadata_args)
|
||||
|
||||
def test_create_metadata_name_conflicts_with_built_in_field(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata creation fails when name conflicts with built-in field names.
|
||||
@ -260,7 +262,9 @@ class TestMetadataService:
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
MetadataService.create_metadata(dataset.id, metadata_args)
|
||||
|
||||
def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_metadata_name_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful metadata name update with valid parameters.
|
||||
"""
|
||||
@ -291,12 +295,13 @@ class TestMetadataService:
|
||||
assert result.updated_at is not None
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.name == new_name
|
||||
|
||||
def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_metadata_name_too_long(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata name update fails when new name exceeds 255 characters.
|
||||
"""
|
||||
@ -323,7 +328,9 @@ class TestMetadataService:
|
||||
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name)
|
||||
|
||||
def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_metadata_name_already_exists(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata name update fails when new name already exists in the same dataset.
|
||||
"""
|
||||
@ -351,7 +358,7 @@ class TestMetadataService:
|
||||
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata")
|
||||
|
||||
def test_update_metadata_name_conflicts_with_built_in_field(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata name update fails when new name conflicts with built-in field names.
|
||||
@ -378,7 +385,9 @@ class TestMetadataService:
|
||||
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
|
||||
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
|
||||
|
||||
def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_metadata_name_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata name update fails when metadata ID does not exist.
|
||||
"""
|
||||
@ -406,7 +415,7 @@ class TestMetadataService:
|
||||
# Assert: Verify the method returns None when metadata is not found
|
||||
assert result is None
|
||||
|
||||
def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful metadata deletion with valid parameters.
|
||||
"""
|
||||
@ -434,12 +443,11 @@ class TestMetadataService:
|
||||
assert result.id == metadata.id
|
||||
|
||||
# Verify metadata was deleted from database
|
||||
from extensions.ext_database import db
|
||||
|
||||
deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first()
|
||||
deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
|
||||
assert deleted_metadata is None
|
||||
|
||||
def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test metadata deletion fails when metadata ID does not exist.
|
||||
"""
|
||||
@ -467,7 +475,7 @@ class TestMetadataService:
|
||||
assert result is None
|
||||
|
||||
def test_delete_metadata_with_document_bindings(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata deletion successfully removes document metadata bindings.
|
||||
@ -500,15 +508,13 @@ class TestMetadataService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set document metadata
|
||||
document.doc_metadata = {"test_metadata": "test_value"}
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.delete_metadata(dataset.id, metadata.id)
|
||||
@ -517,13 +523,13 @@ class TestMetadataService:
|
||||
assert result is not None
|
||||
|
||||
# Verify metadata was deleted from database
|
||||
deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first()
|
||||
deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
|
||||
assert deleted_metadata is None
|
||||
|
||||
# Note: The service attempts to update document metadata but may not succeed
|
||||
# due to mock configuration. The main functionality (metadata deletion) is verified.
|
||||
|
||||
def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of built-in metadata fields.
|
||||
"""
|
||||
@ -548,7 +554,9 @@ class TestMetadataService:
|
||||
assert "string" in field_types
|
||||
assert "time" in field_types
|
||||
|
||||
def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_enable_built_in_field_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful enabling of built-in fields for a dataset.
|
||||
"""
|
||||
@ -579,16 +587,15 @@ class TestMetadataService:
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is True
|
||||
|
||||
# Note: Document metadata update depends on DocumentService mock working correctly
|
||||
# The main functionality (enabling built-in fields) is verified
|
||||
|
||||
def test_enable_built_in_field_already_enabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test enabling built-in fields when they are already enabled.
|
||||
@ -607,10 +614,9 @@ class TestMetadataService:
|
||||
|
||||
# Enable built-in fields first
|
||||
dataset.built_in_field_enabled = True
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
|
||||
@ -619,11 +625,11 @@ class TestMetadataService:
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert: Verify the method returns early without changes
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is True
|
||||
|
||||
def test_enable_built_in_field_with_no_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test enabling built-in fields for a dataset with no documents.
|
||||
@ -647,12 +653,13 @@ class TestMetadataService:
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is True
|
||||
|
||||
def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_disable_built_in_field_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful disabling of built-in fields for a dataset.
|
||||
"""
|
||||
@ -673,10 +680,9 @@ class TestMetadataService:
|
||||
|
||||
# Enable built-in fields first
|
||||
dataset.built_in_field_enabled = True
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set document metadata with built-in fields
|
||||
document.doc_metadata = {
|
||||
@ -686,8 +692,8 @@ class TestMetadataService:
|
||||
BuiltInField.last_update_date: 1234567890.0,
|
||||
BuiltInField.source: "test_source",
|
||||
}
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [
|
||||
@ -698,14 +704,14 @@ class TestMetadataService:
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is False
|
||||
|
||||
# Note: Document metadata update depends on DocumentService mock working correctly
|
||||
# The main functionality (disabling built-in fields) is verified
|
||||
|
||||
def test_disable_built_in_field_already_disabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test disabling built-in fields when they are already disabled.
|
||||
@ -732,13 +738,12 @@ class TestMetadataService:
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert: Verify the method returns early without changes
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is False
|
||||
|
||||
def test_disable_built_in_field_with_no_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test disabling built-in fields for a dataset with no documents.
|
||||
@ -757,10 +762,9 @@ class TestMetadataService:
|
||||
|
||||
# Enable built-in fields first
|
||||
dataset.built_in_field_enabled = True
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list
|
||||
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
|
||||
@ -769,10 +773,12 @@ class TestMetadataService:
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
assert dataset.built_in_field_enabled is False
|
||||
|
||||
def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_documents_metadata_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful update of documents metadata.
|
||||
"""
|
||||
@ -815,24 +821,25 @@ class TestMetadataService:
|
||||
MetadataService.update_documents_metadata(dataset, operation_data)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Verify document metadata was updated
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.doc_metadata is not None
|
||||
assert "test_metadata" in document.doc_metadata
|
||||
assert document.doc_metadata["test_metadata"] == "test_value"
|
||||
|
||||
# Verify metadata binding was created
|
||||
binding = (
|
||||
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first()
|
||||
db_session_with_containers.query(DatasetMetadataBinding)
|
||||
.filter_by(metadata_id=metadata.id, document_id=document.id)
|
||||
.first()
|
||||
)
|
||||
assert binding is not None
|
||||
assert binding.tenant_id == tenant.id
|
||||
assert binding.dataset_id == dataset.id
|
||||
|
||||
def test_update_documents_metadata_with_built_in_fields_enabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test update of documents metadata when built-in fields are enabled.
|
||||
@ -850,10 +857,9 @@ class TestMetadataService:
|
||||
|
||||
# Enable built-in fields
|
||||
dataset.built_in_field_enabled = True
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
@ -884,7 +890,7 @@ class TestMetadataService:
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify document metadata was updated with both custom and built-in fields
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.doc_metadata is not None
|
||||
assert "test_metadata" in document.doc_metadata
|
||||
assert document.doc_metadata["test_metadata"] == "test_value"
|
||||
@ -893,7 +899,7 @@ class TestMetadataService:
|
||||
# The main functionality (custom metadata update) is verified
|
||||
|
||||
def test_update_documents_metadata_document_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test update of documents metadata when document is not found.
|
||||
@ -936,7 +942,7 @@ class TestMetadataService:
|
||||
MetadataService.update_documents_metadata(dataset, operation_data)
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_dataset_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata lock check for dataset operations.
|
||||
@ -959,7 +965,7 @@ class TestMetadataService:
|
||||
assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}"
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_document_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata lock check for document operations.
|
||||
@ -982,7 +988,7 @@ class TestMetadataService:
|
||||
assert call_args[0][0] == f"document_metadata_lock_{document_id}"
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_lock_exists(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata lock check when lock already exists.
|
||||
@ -999,7 +1005,7 @@ class TestMetadataService:
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_document_lock_exists(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test metadata lock check when document lock already exists.
|
||||
@ -1013,7 +1019,9 @@ class TestMetadataService:
|
||||
with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."):
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, document_id)
|
||||
|
||||
def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_dataset_metadatas_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of dataset metadata information.
|
||||
"""
|
||||
@ -1046,10 +1054,8 @@ class TestMetadataService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
@ -1071,7 +1077,7 @@ class TestMetadataService:
|
||||
assert result["built_in_field_enabled"] is False
|
||||
|
||||
def test_get_dataset_metadatas_with_built_in_fields_enabled(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of dataset metadata when built-in fields are enabled.
|
||||
@ -1086,10 +1092,9 @@ class TestMetadataService:
|
||||
|
||||
# Enable built-in fields
|
||||
dataset.built_in_field_enabled = True
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
|
||||
@ -1114,7 +1119,9 @@ class TestMetadataService:
|
||||
# Verify built-in field status
|
||||
assert result["built_in_field_enabled"] is True
|
||||
|
||||
def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_dataset_metadatas_no_metadata(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of dataset metadata when no metadata exists.
|
||||
"""
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
from models.model import Account, Tenant
|
||||
@ -67,7 +68,7 @@ class TestModelLoadBalancingService:
|
||||
"credential_schema": mock_credential_schema,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -88,18 +89,16 @@ class TestModelLoadBalancingService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -108,8 +107,8 @@ class TestModelLoadBalancingService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
@ -117,7 +116,7 @@ class TestModelLoadBalancingService:
|
||||
return account, tenant
|
||||
|
||||
def _create_test_provider_and_setting(
|
||||
self, db_session_with_containers, tenant_id, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Helper method to create a test provider and provider model setting.
|
||||
@ -132,8 +131,6 @@ class TestModelLoadBalancingService:
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create provider
|
||||
provider = Provider(
|
||||
tenant_id=tenant_id,
|
||||
@ -141,8 +138,8 @@ class TestModelLoadBalancingService:
|
||||
provider_type="custom",
|
||||
is_valid=True,
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create provider model setting
|
||||
provider_model_setting = ProviderModelSetting(
|
||||
@ -153,12 +150,14 @@ class TestModelLoadBalancingService:
|
||||
enabled=True,
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
db.session.add(provider_model_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider_model_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return provider, provider_model_setting
|
||||
|
||||
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_enable_model_load_balancing_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful model load balancing enablement.
|
||||
|
||||
@ -193,14 +192,15 @@ class TestModelLoadBalancingService:
|
||||
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(provider)
|
||||
db.session.refresh(provider_model_setting)
|
||||
db_session_with_containers.refresh(provider)
|
||||
db_session_with_containers.refresh(provider_model_setting)
|
||||
assert provider.id is not None
|
||||
assert provider_model_setting.id is not None
|
||||
|
||||
def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_disable_model_load_balancing_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful model load balancing disablement.
|
||||
|
||||
@ -235,15 +235,14 @@ class TestModelLoadBalancingService:
|
||||
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(provider)
|
||||
db.session.refresh(provider_model_setting)
|
||||
db_session_with_containers.refresh(provider)
|
||||
db_session_with_containers.refresh(provider_model_setting)
|
||||
assert provider.id is not None
|
||||
assert provider_model_setting.id is not None
|
||||
|
||||
def test_enable_model_load_balancing_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when provider does not exist.
|
||||
@ -275,11 +274,12 @@ class TestModelLoadBalancingService:
|
||||
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
|
||||
|
||||
# Verify no database state changes occurred
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.rollback()
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_load_balancing_configs_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of load balancing configurations.
|
||||
|
||||
@ -298,7 +298,6 @@ class TestModelLoadBalancingService:
|
||||
)
|
||||
|
||||
# Create load balancing config
|
||||
from extensions.ext_database import db
|
||||
|
||||
load_balancing_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant.id,
|
||||
@ -309,11 +308,11 @@ class TestModelLoadBalancingService:
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(load_balancing_config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(load_balancing_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Verify the config was created
|
||||
db.session.refresh(load_balancing_config)
|
||||
db_session_with_containers.refresh(load_balancing_config)
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
# Setup mocks for get_load_balancing_configs method
|
||||
@ -358,11 +357,11 @@ class TestModelLoadBalancingService:
|
||||
assert configs[0]["ttl"] == 0
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(load_balancing_config)
|
||||
db_session_with_containers.refresh(load_balancing_config)
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
def test_get_load_balancing_configs_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when provider does not exist in get_load_balancing_configs.
|
||||
@ -394,12 +393,11 @@ class TestModelLoadBalancingService:
|
||||
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
|
||||
|
||||
# Verify no database state changes occurred
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.rollback()
|
||||
db_session_with_containers.rollback()
|
||||
|
||||
def test_get_load_balancing_configs_with_inherit_config(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test load balancing configs retrieval with inherit configuration.
|
||||
@ -419,7 +417,6 @@ class TestModelLoadBalancingService:
|
||||
)
|
||||
|
||||
# Create load balancing config
|
||||
from extensions.ext_database import db
|
||||
|
||||
load_balancing_config = LoadBalancingModelConfig(
|
||||
tenant_id=tenant.id,
|
||||
@ -430,8 +427,8 @@ class TestModelLoadBalancingService:
|
||||
encrypted_config='{"api_key": "test_key"}',
|
||||
enabled=True,
|
||||
)
|
||||
db.session.add(load_balancing_config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(load_balancing_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks for inherit config scenario
|
||||
mock_provider_config = mock_external_service_dependencies["provider_config"]
|
||||
@ -467,11 +464,11 @@ class TestModelLoadBalancingService:
|
||||
assert configs[1]["name"] == "config1"
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(load_balancing_config)
|
||||
db_session_with_containers.refresh(load_balancing_config)
|
||||
assert load_balancing_config.id is not None
|
||||
|
||||
# Verify inherit config was created in database
|
||||
inherit_configs = db.session.scalars(
|
||||
inherit_configs = db_session_with_containers.scalars(
|
||||
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
|
||||
).all()
|
||||
assert len(inherit_configs) == 1
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
|
||||
@ -29,7 +30,7 @@ class TestModelProviderService:
|
||||
"model_provider_factory": mock_model_provider_factory,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -50,18 +51,16 @@ class TestModelProviderService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -70,8 +69,8 @@ class TestModelProviderService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
@ -80,7 +79,7 @@ class TestModelProviderService:
|
||||
|
||||
def _create_test_provider(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies,
|
||||
tenant_id: str,
|
||||
provider_name: str = "openai",
|
||||
@ -109,16 +108,14 @@ class TestModelProviderService:
|
||||
quota_used=0,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return provider
|
||||
|
||||
def _create_test_provider_model(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
@ -149,16 +146,14 @@ class TestModelProviderService:
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider_model)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider_model)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return provider_model
|
||||
|
||||
def _create_test_provider_model_setting(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
mock_external_service_dependencies,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
@ -190,14 +185,12 @@ class TestModelProviderService:
|
||||
load_balancing_enabled=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider_model_setting)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider_model_setting)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return provider_model_setting
|
||||
|
||||
def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful provider list retrieval.
|
||||
|
||||
@ -275,7 +268,7 @@ class TestModelProviderService:
|
||||
mock_provider_config.is_custom_configuration_available.assert_called_once()
|
||||
|
||||
def test_get_provider_list_with_model_type_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test provider list retrieval with model type filtering.
|
||||
@ -374,7 +367,9 @@ class TestModelProviderService:
|
||||
assert result[0].provider == "cohere"
|
||||
assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types
|
||||
|
||||
def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_models_by_provider_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of models by provider.
|
||||
|
||||
@ -485,7 +480,9 @@ class TestModelProviderService:
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_configurations.get_models.assert_called_once_with(provider="openai")
|
||||
|
||||
def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_provider_credentials_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of provider credentials.
|
||||
|
||||
@ -543,7 +540,7 @@ class TestModelProviderService:
|
||||
mock_method.assert_called_once_with(tenant.id, "openai")
|
||||
|
||||
def test_provider_credentials_validate_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful validation of provider credentials.
|
||||
@ -585,7 +582,7 @@ class TestModelProviderService:
|
||||
mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials)
|
||||
|
||||
def test_provider_credentials_validate_invalid_provider(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test validation failure for non-existent provider.
|
||||
@ -617,7 +614,7 @@ class TestModelProviderService:
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
|
||||
def test_get_default_model_of_model_type_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of default model for a specific model type.
|
||||
@ -673,7 +670,7 @@ class TestModelProviderService:
|
||||
mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM)
|
||||
|
||||
def test_update_default_model_of_model_type_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful update of default model for a specific model type.
|
||||
@ -706,7 +703,9 @@ class TestModelProviderService:
|
||||
tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4"
|
||||
)
|
||||
|
||||
def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_model_provider_icon_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of model provider icon.
|
||||
|
||||
@ -743,7 +742,9 @@ class TestModelProviderService:
|
||||
# Verify mock interactions
|
||||
mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
|
||||
|
||||
def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_switch_preferred_provider_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful switching of preferred provider type.
|
||||
|
||||
@ -779,7 +780,7 @@ class TestModelProviderService:
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.switch_preferred_provider_type.assert_called_once()
|
||||
|
||||
def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful enabling of a model.
|
||||
|
||||
@ -815,7 +816,9 @@ class TestModelProviderService:
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4")
|
||||
|
||||
def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_model_credentials_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of model credentials.
|
||||
|
||||
@ -872,7 +875,9 @@ class TestModelProviderService:
|
||||
# Verify the method was called with correct parameters
|
||||
mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None)
|
||||
|
||||
def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_model_credentials_validate_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful validation of model credentials.
|
||||
|
||||
@ -914,7 +919,9 @@ class TestModelProviderService:
|
||||
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials
|
||||
)
|
||||
|
||||
def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_model_credentials_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful saving of model credentials.
|
||||
|
||||
@ -955,7 +962,9 @@ class TestModelProviderService:
|
||||
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname"
|
||||
)
|
||||
|
||||
def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_remove_model_credentials_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful removal of model credentials.
|
||||
|
||||
@ -993,7 +1002,9 @@ class TestModelProviderService:
|
||||
model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6"
|
||||
)
|
||||
|
||||
def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_models_by_model_type_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of models by model type.
|
||||
|
||||
@ -1070,7 +1081,9 @@ class TestModelProviderService:
|
||||
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
|
||||
mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
|
||||
def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_model_parameter_rules_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of model parameter rules.
|
||||
|
||||
@ -1137,7 +1150,7 @@ class TestModelProviderService:
|
||||
)
|
||||
|
||||
def test_get_model_parameter_rules_no_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parameter rules retrieval when no credentials are available.
|
||||
@ -1181,7 +1194,7 @@ class TestModelProviderService:
|
||||
)
|
||||
|
||||
def test_get_model_parameter_rules_provider_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parameter rules retrieval when provider does not exist.
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.model import EndUser, Message
|
||||
from models.web import SavedMessage
|
||||
@ -38,7 +39,7 @@ class TestSavedMessageService:
|
||||
"message_service": mock_message_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -85,7 +86,7 @@ class TestSavedMessageService:
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, app):
|
||||
def _create_test_end_user(self, db_session_with_containers: Session, app):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
@ -108,14 +109,12 @@ class TestSavedMessageService:
|
||||
is_anonymous=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_message(self, db_session_with_containers, app, user):
|
||||
def _create_test_message(self, db_session_with_containers: Session, app, user):
|
||||
"""
|
||||
Helper method to create a test message for testing.
|
||||
|
||||
@ -143,10 +142,8 @@ class TestSavedMessageService:
|
||||
mode="chat",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create message
|
||||
message = Message(
|
||||
@ -168,13 +165,13 @@ class TestSavedMessageService:
|
||||
status="success",
|
||||
)
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return message
|
||||
|
||||
def test_pagination_by_last_id_success_with_account_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with account user.
|
||||
@ -207,10 +204,8 @@ class TestSavedMessageService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all([saved_message1, saved_message2])
|
||||
db.session.commit()
|
||||
db_session_with_containers.add_all([saved_message1, saved_message2])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock MessageService.pagination_by_last_id return value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
@ -240,15 +235,15 @@ class TestSavedMessageService:
|
||||
assert actual_include_ids == expected_include_ids
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message1)
|
||||
db.session.refresh(saved_message2)
|
||||
db_session_with_containers.refresh(saved_message1)
|
||||
db_session_with_containers.refresh(saved_message2)
|
||||
assert saved_message1.id is not None
|
||||
assert saved_message2.id is not None
|
||||
assert saved_message1.created_by_role == "account"
|
||||
assert saved_message2.created_by_role == "account"
|
||||
|
||||
def test_pagination_by_last_id_success_with_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with end user.
|
||||
@ -282,10 +277,8 @@ class TestSavedMessageService:
|
||||
created_by=end_user.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all([saved_message1, saved_message2])
|
||||
db.session.commit()
|
||||
db_session_with_containers.add_all([saved_message1, saved_message2])
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock MessageService.pagination_by_last_id return value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
@ -317,14 +310,16 @@ class TestSavedMessageService:
|
||||
assert actual_include_ids == expected_include_ids
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message1)
|
||||
db.session.refresh(saved_message2)
|
||||
db_session_with_containers.refresh(saved_message1)
|
||||
db_session_with_containers.refresh(saved_message2)
|
||||
assert saved_message1.id is not None
|
||||
assert saved_message2.id is not None
|
||||
assert saved_message1.created_by_role == "end_user"
|
||||
assert saved_message2.created_by_role == "end_user"
|
||||
|
||||
def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_success_with_new_message(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful save of a new message.
|
||||
|
||||
@ -347,10 +342,9 @@ class TestSavedMessageService:
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was created in database
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -373,10 +367,12 @@ class TestSavedMessageService:
|
||||
)
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(saved_message)
|
||||
db_session_with_containers.refresh(saved_message)
|
||||
assert saved_message.id is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_last_id_error_no_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
@ -396,12 +392,11 @@ class TestSavedMessageService:
|
||||
assert "User is required" in str(exc_info.value)
|
||||
|
||||
# Verify no database operations were performed
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_messages = db.session.query(SavedMessage).all()
|
||||
saved_messages = db_session_with_containers.query(SavedMessage).all()
|
||||
assert len(saved_messages) == 0
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
@ -422,10 +417,9 @@ class TestSavedMessageService:
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -435,7 +429,9 @@ class TestSavedMessageService:
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_success_existing_message(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
@ -457,14 +453,12 @@ class TestSavedMessageService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(saved_message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
assert (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -481,7 +475,7 @@ class TestSavedMessageService:
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -494,11 +488,13 @@ class TestSavedMessageService:
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db.session.query(Message).where(Message.id == message.id).first() is not None
|
||||
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
|
||||
|
||||
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_last_id_error_no_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when no user is provided.
|
||||
|
||||
@ -522,7 +518,7 @@ class TestSavedMessageService:
|
||||
# Instead, we verify that the error was properly raised
|
||||
pass
|
||||
|
||||
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when saving message with no user.
|
||||
|
||||
@ -543,10 +539,9 @@ class TestSavedMessageService:
|
||||
assert result is None
|
||||
|
||||
# Verify no saved message was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -556,7 +551,9 @@ class TestSavedMessageService:
|
||||
|
||||
assert saved_message is None
|
||||
|
||||
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_success_existing_message(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful deletion of an existing saved message.
|
||||
|
||||
@ -578,14 +575,12 @@ class TestSavedMessageService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(saved_message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(saved_message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Verify saved message exists
|
||||
assert (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -602,7 +597,7 @@ class TestSavedMessageService:
|
||||
# Assert: Verify the expected outcomes
|
||||
# Check if saved message was deleted from database
|
||||
deleted_saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
db_session_with_containers.query(SavedMessage)
|
||||
.where(
|
||||
SavedMessage.app_id == app.id,
|
||||
SavedMessage.message_id == message.id,
|
||||
@ -615,6 +610,6 @@ class TestSavedMessageService:
|
||||
assert deleted_saved_message is None
|
||||
|
||||
# Verify database state
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
# The message should still exist, only the saved_message should be deleted
|
||||
assert db.session.query(Message).where(Message.id == message.id).first() is not None
|
||||
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import create_autospec, patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@ -29,7 +30,7 @@ class TestTagService:
|
||||
"current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -50,18 +51,16 @@ class TestTagService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -70,8 +69,8 @@ class TestTagService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
@ -82,7 +81,7 @@ class TestTagService:
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id):
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id):
|
||||
"""
|
||||
Helper method to create a test dataset for testing.
|
||||
|
||||
@ -107,14 +106,12 @@ class TestTagService:
|
||||
created_by=mock_external_service_dependencies["current_user"].id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id):
|
||||
def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id):
|
||||
"""
|
||||
Helper method to create a test app for testing.
|
||||
|
||||
@ -141,15 +138,13 @@ class TestTagService:
|
||||
created_by=mock_external_service_dependencies["current_user"].id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app
|
||||
|
||||
def _create_test_tags(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3
|
||||
):
|
||||
"""
|
||||
Helper method to create test tags for testing.
|
||||
@ -176,16 +171,14 @@ class TestTagService:
|
||||
)
|
||||
tags.append(tag)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for tag in tags:
|
||||
db.session.add(tag)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tag)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tags
|
||||
|
||||
def _create_test_tag_bindings(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id
|
||||
):
|
||||
"""
|
||||
Helper method to create test tag bindings for testing.
|
||||
@ -211,15 +204,13 @@ class TestTagService:
|
||||
)
|
||||
tag_bindings.append(tag_binding)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for tag_binding in tag_bindings:
|
||||
db.session.add(tag_binding)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tag_binding)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return tag_bindings
|
||||
|
||||
def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of tags with binding count.
|
||||
|
||||
@ -270,7 +261,9 @@ class TestTagService:
|
||||
# The ordering is handled by the database, we just verify the results are returned
|
||||
assert len(result) == 3
|
||||
|
||||
def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tags_with_keyword_filter(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag retrieval with keyword filtering.
|
||||
|
||||
@ -291,12 +284,11 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Update tag names to make them searchable
|
||||
from extensions.ext_database import db
|
||||
|
||||
tags[0].name = "python_development"
|
||||
tags[1].name = "machine_learning"
|
||||
tags[2].name = "web_development"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test with keyword filter
|
||||
result = TagService.get_tags("app", tenant.id, keyword="development")
|
||||
@ -314,7 +306,7 @@ class TestTagService:
|
||||
assert len(result_no_match) == 0
|
||||
|
||||
def test_get_tags_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test tag retrieval with special characters in keyword to verify SQL injection prevention.
|
||||
@ -330,8 +322,6 @@ class TestTagService:
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create tags with special characters in names
|
||||
tag_with_percent = Tag(
|
||||
name="50% discount",
|
||||
@ -340,7 +330,7 @@ class TestTagService:
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_percent.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_percent)
|
||||
db_session_with_containers.add(tag_with_percent)
|
||||
|
||||
tag_with_underscore = Tag(
|
||||
name="test_data_tag",
|
||||
@ -349,7 +339,7 @@ class TestTagService:
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_underscore.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_underscore)
|
||||
db_session_with_containers.add(tag_with_underscore)
|
||||
|
||||
tag_with_backslash = Tag(
|
||||
name="path\\to\\tag",
|
||||
@ -358,7 +348,7 @@ class TestTagService:
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_with_backslash.id = str(uuid.uuid4())
|
||||
db.session.add(tag_with_backslash)
|
||||
db_session_with_containers.add(tag_with_backslash)
|
||||
|
||||
# Create tag that should NOT match
|
||||
tag_no_match = Tag(
|
||||
@ -368,9 +358,9 @@ class TestTagService:
|
||||
created_by=account.id,
|
||||
)
|
||||
tag_no_match.id = str(uuid.uuid4())
|
||||
db.session.add(tag_no_match)
|
||||
db_session_with_containers.add(tag_no_match)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Test 1 - Search with % character
|
||||
result = TagService.get_tags("app", tenant.id, keyword="50%")
|
||||
@ -392,7 +382,7 @@ class TestTagService:
|
||||
assert len(result) == 1
|
||||
assert all("50%" in item.name for item in result)
|
||||
|
||||
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tag retrieval when no tags exist.
|
||||
|
||||
@ -414,7 +404,9 @@ class TestTagService:
|
||||
assert len(result) == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_target_ids_by_tag_ids_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of target IDs by tag IDs.
|
||||
|
||||
@ -469,7 +461,7 @@ class TestTagService:
|
||||
assert second_dataset_count == 1
|
||||
|
||||
def test_get_target_ids_by_tag_ids_empty_tag_ids(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test target ID retrieval with empty tag IDs list.
|
||||
@ -493,7 +485,7 @@ class TestTagService:
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_target_ids_by_tag_ids_no_matching_tags(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test target ID retrieval when no tags match the criteria.
|
||||
@ -521,7 +513,7 @@ class TestTagService:
|
||||
assert len(result) == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of tags by tag name.
|
||||
|
||||
@ -542,11 +534,10 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Update tag names to make them searchable
|
||||
from extensions.ext_database import db
|
||||
|
||||
tags[0].name = "python_tag"
|
||||
tags[1].name = "ml_tag"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
|
||||
@ -558,7 +549,9 @@ class TestTagService:
|
||||
assert result[0].type == "app"
|
||||
assert result[0].tenant_id == tenant.id
|
||||
|
||||
def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tag_by_tag_name_no_matches(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag retrieval by name when no matches exist.
|
||||
|
||||
@ -580,7 +573,9 @@ class TestTagService:
|
||||
assert len(result) == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tag_by_tag_name_empty_parameters(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag retrieval by name with empty parameters.
|
||||
|
||||
@ -605,7 +600,9 @@ class TestTagService:
|
||||
assert result_empty_name is not None
|
||||
assert len(result_empty_name) == 0
|
||||
|
||||
def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tags_by_target_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of tags by target ID.
|
||||
|
||||
@ -644,7 +641,9 @@ class TestTagService:
|
||||
assert tag.tenant_id == tenant.id
|
||||
assert tag.id in [t.id for t in tags]
|
||||
|
||||
def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tags_by_target_id_no_bindings(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag retrieval by target ID when no tags are bound.
|
||||
|
||||
@ -669,7 +668,7 @@ class TestTagService:
|
||||
assert len(result) == 0
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful tag creation.
|
||||
|
||||
@ -698,17 +697,18 @@ class TestTagService:
|
||||
assert result.id is not None
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
|
||||
# Verify tag was actually saved to database
|
||||
saved_tag = db.session.query(Tag).where(Tag.id == result.id).first()
|
||||
saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first()
|
||||
assert saved_tag is not None
|
||||
assert saved_tag.name == "test_tag_name"
|
||||
|
||||
def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_tags_duplicate_name_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag creation with duplicate name.
|
||||
|
||||
@ -731,7 +731,7 @@ class TestTagService:
|
||||
TagService.save_tags(tag_args)
|
||||
assert "Tag name already exists" in str(exc_info.value)
|
||||
|
||||
def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful tag update.
|
||||
|
||||
@ -763,17 +763,16 @@ class TestTagService:
|
||||
assert result.id == tag.id
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.name == "updated_name"
|
||||
|
||||
# Verify tag was actually updated in database
|
||||
updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first()
|
||||
updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
|
||||
assert updated_tag is not None
|
||||
assert updated_tag.name == "updated_name"
|
||||
|
||||
def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tag update for non-existent tag.
|
||||
|
||||
@ -799,7 +798,9 @@ class TestTagService:
|
||||
TagService.update_tags(update_args, non_existent_tag_id)
|
||||
assert "Tag not found" in str(exc_info.value)
|
||||
|
||||
def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_tags_duplicate_name_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag update with duplicate name.
|
||||
|
||||
@ -828,7 +829,9 @@ class TestTagService:
|
||||
TagService.update_tags(update_args, tag2.id)
|
||||
assert "Tag name already exists" in str(exc_info.value)
|
||||
|
||||
def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tag_binding_count_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of tag binding count.
|
||||
|
||||
@ -863,7 +866,7 @@ class TestTagService:
|
||||
assert result_tag_without_bindings == 0
|
||||
|
||||
def test_get_tag_binding_count_non_existent_tag(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test binding count retrieval for non-existent tag.
|
||||
@ -889,7 +892,7 @@ class TestTagService:
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == 0
|
||||
|
||||
def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful tag deletion.
|
||||
|
||||
@ -916,12 +919,11 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Verify tag and binding exist before deletion
|
||||
from extensions.ext_database import db
|
||||
|
||||
tag_before = db.session.query(Tag).where(Tag.id == tag.id).first()
|
||||
tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
|
||||
assert tag_before is not None
|
||||
|
||||
binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
|
||||
binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
|
||||
assert binding_before is not None
|
||||
|
||||
# Act: Execute the method under test
|
||||
@ -929,14 +931,14 @@ class TestTagService:
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify tag was deleted
|
||||
tag_after = db.session.query(Tag).where(Tag.id == tag.id).first()
|
||||
tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
|
||||
assert tag_after is None
|
||||
|
||||
# Verify tag binding was deleted
|
||||
binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
|
||||
binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
|
||||
assert binding_after is None
|
||||
|
||||
def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test tag deletion for non-existent tag.
|
||||
|
||||
@ -960,7 +962,7 @@ class TestTagService:
|
||||
TagService.delete_tag(non_existent_tag_id)
|
||||
assert "Tag not found" in str(exc_info.value)
|
||||
|
||||
def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful tag binding creation.
|
||||
|
||||
@ -988,12 +990,11 @@ class TestTagService:
|
||||
TagService.save_tag_binding(binding_args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Verify tag bindings were created
|
||||
for tag in tags:
|
||||
binding = (
|
||||
db.session.query(TagBinding)
|
||||
db_session_with_containers.query(TagBinding)
|
||||
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
|
||||
.first()
|
||||
)
|
||||
@ -1001,7 +1002,9 @@ class TestTagService:
|
||||
assert binding.tenant_id == tenant.id
|
||||
assert binding.created_by == account.id
|
||||
|
||||
def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_tag_binding_duplicate_handling(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag binding creation with duplicate bindings.
|
||||
|
||||
@ -1032,15 +1035,16 @@ class TestTagService:
|
||||
TagService.save_tag_binding(binding_args)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Verify only one binding exists
|
||||
bindings = db.session.scalars(
|
||||
bindings = db_session_with_containers.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
|
||||
).all()
|
||||
assert len(bindings) == 1
|
||||
|
||||
def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_save_tag_binding_invalid_target_type(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag binding creation with invalid target type.
|
||||
|
||||
@ -1071,7 +1075,7 @@ class TestTagService:
|
||||
TagService.save_tag_binding(binding_args)
|
||||
assert "Invalid binding type" in str(exc_info.value)
|
||||
|
||||
def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful tag binding deletion.
|
||||
|
||||
@ -1098,10 +1102,11 @@ class TestTagService:
|
||||
)
|
||||
|
||||
# Verify binding exists before deletion
|
||||
from extensions.ext_database import db
|
||||
|
||||
binding_before = (
|
||||
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first()
|
||||
db_session_with_containers.query(TagBinding)
|
||||
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
|
||||
.first()
|
||||
)
|
||||
assert binding_before is not None
|
||||
|
||||
@ -1112,12 +1117,14 @@ class TestTagService:
|
||||
# Assert: Verify the expected outcomes
|
||||
# Verify tag binding was deleted
|
||||
binding_after = (
|
||||
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first()
|
||||
db_session_with_containers.query(TagBinding)
|
||||
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
|
||||
.first()
|
||||
)
|
||||
assert binding_after is None
|
||||
|
||||
def test_delete_tag_binding_non_existent_binding(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tag binding deletion for non-existent binding.
|
||||
@ -1145,15 +1152,14 @@ class TestTagService:
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# No error should be raised, and database state should remain unchanged
|
||||
from extensions.ext_database import db
|
||||
|
||||
bindings = db.session.scalars(
|
||||
bindings = db_session_with_containers.scalars(
|
||||
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
|
||||
).all()
|
||||
assert len(bindings) == 0
|
||||
|
||||
def test_check_target_exists_knowledge_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful target existence check for knowledge type.
|
||||
@ -1179,7 +1185,7 @@ class TestTagService:
|
||||
# No exception should be raised for existing dataset
|
||||
|
||||
def test_check_target_exists_knowledge_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test target existence check for non-existent knowledge dataset.
|
||||
@ -1204,7 +1210,9 @@ class TestTagService:
|
||||
TagService.check_target_exists("knowledge", non_existent_dataset_id)
|
||||
assert "Dataset not found" in str(exc_info.value)
|
||||
|
||||
def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_check_target_exists_app_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful target existence check for app type.
|
||||
|
||||
@ -1228,7 +1236,9 @@ class TestTagService:
|
||||
# Assert: Verify the expected outcomes
|
||||
# No exception should be raised for existing app
|
||||
|
||||
def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_check_target_exists_app_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test target existence check for non-existent app.
|
||||
|
||||
@ -1252,7 +1262,9 @@ class TestTagService:
|
||||
TagService.check_target_exists("app", non_existent_app_id)
|
||||
assert "App not found" in str(exc_info.value)
|
||||
|
||||
def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_check_target_exists_invalid_type(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test target existence check for invalid type.
|
||||
|
||||
|
||||
@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
|
||||
from extensions.ext_database import db
|
||||
from models.provider_ids import TriggerProviderID
|
||||
from models.trigger import TriggerSubscription
|
||||
from services.trigger.trigger_provider_service import TriggerProviderService
|
||||
@ -47,7 +47,7 @@ class TestTriggerProviderService:
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -84,7 +84,7 @@ class TestTriggerProviderService:
|
||||
|
||||
def _create_test_subscription(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
tenant_id,
|
||||
user_id,
|
||||
provider_id,
|
||||
@ -135,14 +135,14 @@ class TestTriggerProviderService:
|
||||
expires_at=-1,
|
||||
)
|
||||
|
||||
db.session.add(subscription)
|
||||
db.session.commit()
|
||||
db.session.refresh(subscription)
|
||||
db_session_with_containers.add(subscription)
|
||||
db_session_with_containers.commit()
|
||||
db_session_with_containers.refresh(subscription)
|
||||
|
||||
return subscription
|
||||
|
||||
def test_rebuild_trigger_subscription_success_with_merged_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
|
||||
@ -217,7 +217,7 @@ class TestTriggerProviderService:
|
||||
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(subscription)
|
||||
db_session_with_containers.refresh(subscription)
|
||||
assert subscription.name == "updated_name"
|
||||
assert subscription.parameters == {"param1": "updated_value"}
|
||||
|
||||
@ -244,7 +244,7 @@ class TestTriggerProviderService:
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_new_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are new (no HIDDEN_VALUE).
|
||||
@ -304,7 +304,7 @@ class TestTriggerProviderService:
|
||||
assert subscribe_credentials["api_secret"] == "completely-new-secret"
|
||||
|
||||
def test_rebuild_trigger_subscription_with_all_hidden_values(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
|
||||
@ -363,7 +363,7 @@ class TestTriggerProviderService:
|
||||
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
|
||||
|
||||
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
|
||||
@ -422,7 +422,7 @@ class TestTriggerProviderService:
|
||||
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
|
||||
|
||||
def test_rebuild_trigger_subscription_rollback_on_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that transaction is rolled back on error.
|
||||
@ -470,12 +470,12 @@ class TestTriggerProviderService:
|
||||
)
|
||||
|
||||
# Verify subscription state was not changed (rolled back)
|
||||
db.session.refresh(subscription)
|
||||
db_session_with_containers.refresh(subscription)
|
||||
assert subscription.name == original_name
|
||||
assert subscription.parameters == original_parameters
|
||||
|
||||
def test_rebuild_trigger_subscription_subscription_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error when subscription is not found.
|
||||
@ -501,7 +501,7 @@ class TestTriggerProviderService:
|
||||
)
|
||||
|
||||
def test_rebuild_trigger_subscription_name_uniqueness_check(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that name uniqueness is checked when updating name.
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from models import Account
|
||||
@ -45,7 +46,7 @@ class TestWebConversationService:
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -90,7 +91,7 @@ class TestWebConversationService:
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_end_user(self, db_session_with_containers, app):
|
||||
def _create_test_end_user(self, db_session_with_containers: Session, app):
|
||||
"""
|
||||
Helper method to create a test end user for testing.
|
||||
|
||||
@ -111,14 +112,12 @@ class TestWebConversationService:
|
||||
tenant_id=app.tenant_id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return end_user
|
||||
|
||||
def _create_test_conversation(self, db_session_with_containers, app, user, fake):
|
||||
def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake):
|
||||
"""
|
||||
Helper method to create a test conversation for testing.
|
||||
|
||||
@ -152,14 +151,14 @@ class TestWebConversationService:
|
||||
is_deleted=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return conversation
|
||||
|
||||
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pagination_by_last_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination by last ID with basic parameters.
|
||||
"""
|
||||
@ -194,7 +193,7 @@ class TestWebConversationService:
|
||||
assert result.data[1].updated_at >= result.data[2].updated_at
|
||||
|
||||
def test_pagination_by_last_id_with_pinned_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with pinned conversation filter.
|
||||
@ -222,11 +221,9 @@ class TestWebConversationService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(pinned_conversation1)
|
||||
db.session.add(pinned_conversation2)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(pinned_conversation1)
|
||||
db_session_with_containers.add(pinned_conversation2)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Test pagination with pinned filter
|
||||
result = WebConversationService.pagination_by_last_id(
|
||||
@ -251,7 +248,7 @@ class TestWebConversationService:
|
||||
assert set(returned_ids) == set(expected_ids)
|
||||
|
||||
def test_pagination_by_last_id_with_unpinned_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination by last ID with unpinned conversation filter.
|
||||
@ -273,10 +270,8 @@ class TestWebConversationService:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(pinned_conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(pinned_conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Test pagination with unpinned filter
|
||||
result = WebConversationService.pagination_by_last_id(
|
||||
@ -303,7 +298,7 @@ class TestWebConversationService:
|
||||
expected_unpinned_ids = [conv.id for conv in conversations[1:]]
|
||||
assert set(returned_ids) == set(expected_unpinned_ids)
|
||||
|
||||
def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful pinning of a conversation.
|
||||
"""
|
||||
@ -317,10 +312,9 @@ class TestWebConversationService:
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify the conversation was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -336,7 +330,9 @@ class TestWebConversationService:
|
||||
assert pinned_conversation.created_by_role == "account"
|
||||
assert pinned_conversation.created_by == account.id
|
||||
|
||||
def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pin_conversation_already_pinned(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pinning a conversation that is already pinned (should not create duplicate).
|
||||
"""
|
||||
@ -353,9 +349,8 @@ class TestWebConversationService:
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify only one pinned conversation record exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversations = db.session.scalars(
|
||||
pinned_conversations = db_session_with_containers.scalars(
|
||||
select(PinnedConversation).where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -366,7 +361,9 @@ class TestWebConversationService:
|
||||
|
||||
assert len(pinned_conversations) == 1
|
||||
|
||||
def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pin_conversation_with_end_user(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pinning a conversation with an end user.
|
||||
"""
|
||||
@ -383,10 +380,9 @@ class TestWebConversationService:
|
||||
WebConversationService.pin(app, conversation.id, end_user)
|
||||
|
||||
# Verify the conversation was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -402,7 +398,7 @@ class TestWebConversationService:
|
||||
assert pinned_conversation.created_by_role == "end_user"
|
||||
assert pinned_conversation.created_by == end_user.id
|
||||
|
||||
def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful unpinning of a conversation.
|
||||
"""
|
||||
@ -416,10 +412,9 @@ class TestWebConversationService:
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify it was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -436,7 +431,7 @@ class TestWebConversationService:
|
||||
|
||||
# Verify it was unpinned
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -448,7 +443,9 @@ class TestWebConversationService:
|
||||
|
||||
assert pinned_conversation is None
|
||||
|
||||
def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_unpin_conversation_not_pinned(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test unpinning a conversation that is not pinned (should not cause error).
|
||||
"""
|
||||
@ -462,10 +459,9 @@ class TestWebConversationService:
|
||||
WebConversationService.unpin(app, conversation.id, account)
|
||||
|
||||
# Verify no pinned conversation record exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -478,7 +474,7 @@ class TestWebConversationService:
|
||||
assert pinned_conversation is None
|
||||
|
||||
def test_pagination_by_last_id_user_required_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that pagination_by_last_id raises ValueError when user is None.
|
||||
@ -499,7 +495,7 @@ class TestWebConversationService:
|
||||
sort_by="-updated_at",
|
||||
)
|
||||
|
||||
def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test that pin method returns early when user is None.
|
||||
"""
|
||||
@ -513,10 +509,9 @@ class TestWebConversationService:
|
||||
WebConversationService.pin(app, conversation.id, None)
|
||||
|
||||
# Verify no pinned conversation was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -526,7 +521,9 @@ class TestWebConversationService:
|
||||
|
||||
assert pinned_conversation is None
|
||||
|
||||
def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_unpin_conversation_user_none(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that unpin method returns early when user is None.
|
||||
"""
|
||||
@ -540,10 +537,9 @@ class TestWebConversationService:
|
||||
WebConversationService.pin(app, conversation.id, account)
|
||||
|
||||
# Verify it was pinned
|
||||
from extensions.ext_database import db
|
||||
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
@ -560,7 +556,7 @@ class TestWebConversationService:
|
||||
|
||||
# Verify the conversation is still pinned
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
db_session_with_containers.query(PinnedConversation)
|
||||
.where(
|
||||
PinnedConversation.app_id == app.id,
|
||||
PinnedConversation.conversation_id == conversation.id,
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from libs.password import hash_password
|
||||
@ -45,7 +46,7 @@ class TestWebAppAuthService:
|
||||
"enterprise_service": mock_enterprise_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -68,18 +69,16 @@ class TestWebAppAuthService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -88,15 +87,17 @@ class TestWebAppAuthService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_with_password(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Helper method to create a test account with password for testing.
|
||||
|
||||
@ -131,18 +132,16 @@ class TestWebAppAuthService:
|
||||
account.password = base64.b64encode(password_hash).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -151,15 +150,17 @@ class TestWebAppAuthService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant, password
|
||||
|
||||
def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant):
|
||||
def _create_test_app_and_site(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant
|
||||
):
|
||||
"""
|
||||
Helper method to create a test app and site for testing.
|
||||
|
||||
@ -188,10 +189,8 @@ class TestWebAppAuthService:
|
||||
enable_api=True,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create site
|
||||
site = Site(
|
||||
@ -203,12 +202,12 @@ class TestWebAppAuthService:
|
||||
status="normal",
|
||||
customize_token_strategy="not_allow",
|
||||
)
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(site)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app, site
|
||||
|
||||
def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful authentication with valid email and password.
|
||||
|
||||
@ -233,14 +232,15 @@ class TestWebAppAuthService:
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.password is not None
|
||||
assert result.password_salt is not None
|
||||
|
||||
def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_authenticate_account_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test authentication with non-existent email.
|
||||
|
||||
@ -262,7 +262,7 @@ class TestWebAppAuthService:
|
||||
with pytest.raises(AccountNotFoundError):
|
||||
WebAppAuthService.authenticate(non_existent_email, "any_password")
|
||||
|
||||
def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test authentication with banned account.
|
||||
|
||||
@ -292,10 +292,8 @@ class TestWebAppAuthService:
|
||||
account.password = base64.b64encode(password_hash).decode()
|
||||
account.password_salt = base64.b64encode(salt).decode()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(AccountLoginError) as exc_info:
|
||||
@ -303,7 +301,9 @@ class TestWebAppAuthService:
|
||||
|
||||
assert "Account is banned." in str(exc_info.value)
|
||||
|
||||
def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_authenticate_invalid_password(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test authentication with invalid password.
|
||||
|
||||
@ -323,7 +323,7 @@ class TestWebAppAuthService:
|
||||
assert "Invalid email or password." in str(exc_info.value)
|
||||
|
||||
def test_authenticate_account_without_password(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test authentication for account without password.
|
||||
@ -345,10 +345,8 @@ class TestWebAppAuthService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(AccountPasswordError) as exc_info:
|
||||
@ -356,7 +354,7 @@ class TestWebAppAuthService:
|
||||
|
||||
assert "Invalid email or password." in str(exc_info.value)
|
||||
|
||||
def test_login_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful login and JWT token generation.
|
||||
|
||||
@ -388,7 +386,9 @@ class TestWebAppAuthService:
|
||||
assert call_args["auth_type"] == "internal"
|
||||
assert "exp" in call_args
|
||||
|
||||
def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_user_through_email_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful user retrieval through email.
|
||||
|
||||
@ -413,12 +413,13 @@ class TestWebAppAuthService:
|
||||
assert result.status == AccountStatus.ACTIVE
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
|
||||
def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_user_through_email_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test user retrieval with non-existent email.
|
||||
|
||||
@ -435,7 +436,9 @@ class TestWebAppAuthService:
|
||||
# Assert: Verify proper handling
|
||||
assert result is None
|
||||
|
||||
def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_user_through_email_banned(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test user retrieval with banned account.
|
||||
|
||||
@ -456,10 +459,8 @@ class TestWebAppAuthService:
|
||||
status=AccountStatus.BANNED,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
@ -468,7 +469,7 @@ class TestWebAppAuthService:
|
||||
assert "Account is banned." in str(exc_info.value)
|
||||
|
||||
def test_send_email_code_login_email_with_account(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test sending email code login email with account.
|
||||
@ -509,7 +510,7 @@ class TestWebAppAuthService:
|
||||
assert "code" in mail_call_args[1]
|
||||
|
||||
def test_send_email_code_login_email_with_email_only(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test sending email code login email with email only.
|
||||
@ -549,7 +550,7 @@ class TestWebAppAuthService:
|
||||
assert "code" in mail_call_args[1]
|
||||
|
||||
def test_send_email_code_login_email_no_email_provided(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test sending email code login email without providing email.
|
||||
@ -566,7 +567,9 @@ class TestWebAppAuthService:
|
||||
|
||||
assert "Email must be provided." in str(exc_info.value)
|
||||
|
||||
def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_email_code_login_data_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of email code login data.
|
||||
|
||||
@ -593,7 +596,9 @@ class TestWebAppAuthService:
|
||||
"mock_token", "email_code_login"
|
||||
)
|
||||
|
||||
def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_email_code_login_data_no_data(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test email code login data retrieval when no data exists.
|
||||
|
||||
@ -617,7 +622,7 @@ class TestWebAppAuthService:
|
||||
)
|
||||
|
||||
def test_revoke_email_code_login_token_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful revocation of email code login token.
|
||||
@ -636,7 +641,7 @@ class TestWebAppAuthService:
|
||||
"mock_token", "email_code_login"
|
||||
)
|
||||
|
||||
def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful end user creation.
|
||||
|
||||
@ -668,14 +673,15 @@ class TestWebAppAuthService:
|
||||
assert result.external_user_id == "enterpriseuser"
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.created_at is not None
|
||||
assert result.updated_at is not None
|
||||
|
||||
def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_end_user_site_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test end user creation with non-existent site code.
|
||||
|
||||
@ -693,7 +699,9 @@ class TestWebAppAuthService:
|
||||
|
||||
assert "Site not found." in str(exc_info.value)
|
||||
|
||||
def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_end_user_app_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test end user creation when app is not found.
|
||||
|
||||
@ -708,10 +716,8 @@ class TestWebAppAuthService:
|
||||
status="normal",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
site = Site(
|
||||
app_id="00000000-0000-0000-0000-000000000000",
|
||||
@ -722,8 +728,8 @@ class TestWebAppAuthService:
|
||||
status="normal",
|
||||
customize_token_strategy="not_allow",
|
||||
)
|
||||
db.session.add(site)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(site)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
with pytest.raises(NotFound) as exc_info:
|
||||
@ -732,7 +738,7 @@ class TestWebAppAuthService:
|
||||
assert "App not found." in str(exc_info.value)
|
||||
|
||||
def test_is_app_require_permission_check_with_access_mode_private(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement for private access mode.
|
||||
@ -751,7 +757,7 @@ class TestWebAppAuthService:
|
||||
assert result is True
|
||||
|
||||
def test_is_app_require_permission_check_with_access_mode_public(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement for public access mode.
|
||||
@ -770,7 +776,7 @@ class TestWebAppAuthService:
|
||||
assert result is False
|
||||
|
||||
def test_is_app_require_permission_check_with_app_code(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement using app code.
|
||||
@ -796,7 +802,7 @@ class TestWebAppAuthService:
|
||||
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
|
||||
|
||||
def test_is_app_require_permission_check_no_parameters(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test permission check requirement with no parameters.
|
||||
@ -814,7 +820,7 @@ class TestWebAppAuthService:
|
||||
assert "Either app_code or app_id must be provided." in str(exc_info.value)
|
||||
|
||||
def test_get_app_auth_type_with_access_mode_public(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app authentication type for public access mode.
|
||||
@ -833,7 +839,7 @@ class TestWebAppAuthService:
|
||||
assert result == WebAppAuthType.PUBLIC
|
||||
|
||||
def test_get_app_auth_type_with_access_mode_private(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app authentication type for private access mode.
|
||||
@ -851,7 +857,9 @@ class TestWebAppAuthService:
|
||||
# Assert: Verify correct result
|
||||
assert result == WebAppAuthType.INTERNAL
|
||||
|
||||
def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_app_auth_type_with_app_code(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app authentication type using app code.
|
||||
|
||||
@ -878,7 +886,9 @@ class TestWebAppAuthService:
|
||||
"enterprise_service"
|
||||
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
|
||||
|
||||
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_app_auth_type_no_parameters(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test app authentication type with no parameters.
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
|
||||
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
|
||||
@ -48,7 +49,7 @@ class TestWorkflowAppService:
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -96,7 +97,7 @@ class TestWorkflowAppService:
|
||||
|
||||
return app, account
|
||||
|
||||
def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test tenant and account for testing.
|
||||
|
||||
@ -126,7 +127,7 @@ class TestWorkflowAppService:
|
||||
|
||||
return tenant, account
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, tenant, account):
|
||||
def _create_test_app(self, db_session_with_containers: Session, tenant, account):
|
||||
"""
|
||||
Helper method to create a test app for testing.
|
||||
|
||||
@ -160,7 +161,7 @@ class TestWorkflowAppService:
|
||||
|
||||
return app
|
||||
|
||||
def _create_test_workflow_data(self, db_session_with_containers, app, account):
|
||||
def _create_test_workflow_data(self, db_session_with_containers: Session, app, account):
|
||||
"""
|
||||
Helper method to create test workflow data for testing.
|
||||
|
||||
@ -174,8 +175,6 @@ class TestWorkflowAppService:
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -188,8 +187,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow run
|
||||
workflow_run = WorkflowRun(
|
||||
@ -212,8 +211,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC),
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow app log
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
@ -227,13 +226,13 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return workflow, workflow_run, workflow_app_log
|
||||
|
||||
def test_get_paginate_workflow_app_logs_basic_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination of workflow app logs with basic parameters.
|
||||
@ -268,13 +267,12 @@ class TestWorkflowAppService:
|
||||
assert log_entry.workflow_run_id == workflow_run.id
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(workflow_app_log)
|
||||
db_session_with_containers.refresh(workflow_app_log)
|
||||
assert workflow_app_log.id is not None
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_keyword_search(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with keyword search functionality.
|
||||
@ -287,11 +285,10 @@ class TestWorkflowAppService:
|
||||
)
|
||||
|
||||
# Update workflow run with searchable content
|
||||
from extensions.ext_database import db
|
||||
|
||||
workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"})
|
||||
workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"})
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test with keyword search
|
||||
service = WorkflowAppService()
|
||||
@ -317,7 +314,7 @@ class TestWorkflowAppService:
|
||||
assert len(result_no_match["data"]) == 0
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
r"""
|
||||
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
|
||||
@ -332,8 +329,6 @@ class TestWorkflowAppService:
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = WorkflowAppService()
|
||||
|
||||
# Test 1: Search with % character
|
||||
@ -353,8 +348,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_1)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(workflow_run_1)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
workflow_app_log_1 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -367,8 +362,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log_1.id = str(uuid.uuid4())
|
||||
workflow_app_log_1.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_1)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log_1)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
|
||||
@ -395,8 +390,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_2)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(workflow_run_2)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
workflow_app_log_2 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -409,8 +404,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log_2.id = str(uuid.uuid4())
|
||||
workflow_app_log_2.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_2)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log_2)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
|
||||
@ -437,8 +432,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run_4)
|
||||
db.session.flush()
|
||||
db_session_with_containers.add(workflow_run_4)
|
||||
db_session_with_containers.flush()
|
||||
|
||||
workflow_app_log_4 = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -451,8 +446,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log_4.id = str(uuid.uuid4())
|
||||
workflow_app_log_4.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log_4)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log_4)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
result = service.get_paginate_workflow_app_logs(
|
||||
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
|
||||
@ -467,7 +462,7 @@ class TestWorkflowAppService:
|
||||
assert workflow_run_4.id not in found_run_ids
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_status_filter(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with status filtering.
|
||||
@ -476,8 +471,6 @@ class TestWorkflowAppService:
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -490,8 +483,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow runs with different statuses
|
||||
statuses = ["succeeded", "failed", "running", "stopped"]
|
||||
@ -519,8 +512,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC) + timedelta(minutes=i),
|
||||
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None,
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -533,8 +526,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_runs.append(workflow_run)
|
||||
workflow_app_logs.append(workflow_app_log)
|
||||
@ -568,7 +561,7 @@ class TestWorkflowAppService:
|
||||
assert result_running["data"][0].workflow_run.status == "running"
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_time_filtering(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with time-based filtering.
|
||||
@ -577,8 +570,6 @@ class TestWorkflowAppService:
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -591,8 +582,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow runs with different timestamps
|
||||
base_time = datetime.now(UTC)
|
||||
@ -627,8 +618,8 @@ class TestWorkflowAppService:
|
||||
created_at=timestamp,
|
||||
finished_at=timestamp + timedelta(minutes=1),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -641,8 +632,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = timestamp
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_runs.append(workflow_run)
|
||||
workflow_app_logs.append(workflow_app_log)
|
||||
@ -682,7 +673,7 @@ class TestWorkflowAppService:
|
||||
assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_pagination(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with different page sizes and limits.
|
||||
@ -691,8 +682,6 @@ class TestWorkflowAppService:
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -705,8 +694,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create 25 workflow runs and logs
|
||||
total_logs = 25
|
||||
@ -734,8 +723,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC) + timedelta(minutes=i),
|
||||
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -748,8 +737,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_runs.append(workflow_run)
|
||||
workflow_app_logs.append(workflow_app_log)
|
||||
@ -798,7 +787,7 @@ class TestWorkflowAppService:
|
||||
assert len(result_large_limit["data"]) == total_logs
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_user_role_filtering(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with user role and session filtering.
|
||||
@ -807,8 +796,6 @@ class TestWorkflowAppService:
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -821,8 +808,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create end user
|
||||
end_user = EndUser(
|
||||
@ -835,8 +822,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow runs and logs for both account and end user
|
||||
workflow_runs = []
|
||||
@ -864,8 +851,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC) + timedelta(minutes=i),
|
||||
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -878,8 +865,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_runs.append(workflow_run)
|
||||
workflow_app_logs.append(workflow_app_log)
|
||||
@ -906,8 +893,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC) + timedelta(minutes=i + 10),
|
||||
finished_at=datetime.now(UTC) + timedelta(minutes=i + 11),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
tenant_id=app.tenant_id,
|
||||
@ -920,8 +907,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_runs.append(workflow_run)
|
||||
workflow_app_logs.append(workflow_app_log)
|
||||
@ -994,7 +981,7 @@ class TestWorkflowAppService:
|
||||
assert "Account not found" in str(exc_info.value)
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with UUID keyword search functionality.
|
||||
@ -1003,8 +990,6 @@ class TestWorkflowAppService:
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -1017,8 +1002,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow run with specific UUID
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
@ -1042,8 +1027,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC),
|
||||
finished_at=datetime.now(UTC) + timedelta(minutes=1),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow app log
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
@ -1057,8 +1042,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Test UUID keyword search
|
||||
service = WorkflowAppService()
|
||||
@ -1085,7 +1070,7 @@ class TestWorkflowAppService:
|
||||
assert result_invalid_uuid["total"] == 0
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_edge_cases(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with edge cases and boundary conditions.
|
||||
@ -1094,8 +1079,6 @@ class TestWorkflowAppService:
|
||||
fake = Faker()
|
||||
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
id=str(uuid.uuid4()),
|
||||
@ -1108,8 +1091,8 @@ class TestWorkflowAppService:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow run with edge case data
|
||||
workflow_run = WorkflowRun(
|
||||
@ -1132,8 +1115,8 @@ class TestWorkflowAppService:
|
||||
created_at=datetime.now(UTC),
|
||||
finished_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow app log
|
||||
workflow_app_log = WorkflowAppLog(
|
||||
@ -1147,8 +1130,8 @@ class TestWorkflowAppService:
|
||||
)
|
||||
workflow_app_log.id = str(uuid.uuid4())
|
||||
workflow_app_log.created_at = datetime.now(UTC)
|
||||
db.session.add(workflow_app_log)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_app_log)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Test edge cases
|
||||
service = WorkflowAppService()
|
||||
@ -1185,7 +1168,7 @@ class TestWorkflowAppService:
|
||||
assert result_high_page["has_more"] is False
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_empty_results(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with empty results and no data scenarios.
|
||||
@ -1252,7 +1235,7 @@ class TestWorkflowAppService:
|
||||
assert "Account not found" in str(exc_info.value)
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with complex query combinations.
|
||||
@ -1352,7 +1335,7 @@ class TestWorkflowAppService:
|
||||
assert len(result_time_status_limit["data"]) <= 2
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_large_dataset_performance(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with large dataset for performance validation.
|
||||
@ -1444,7 +1427,7 @@ class TestWorkflowAppService:
|
||||
assert result_last_page["page"] == 3
|
||||
|
||||
def test_get_paginate_workflow_app_logs_with_tenant_isolation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow app logs pagination with proper tenant isolation.
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from dify_graph.variables.segments import StringSegment
|
||||
@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService:
|
||||
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
|
||||
return {}
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
|
||||
def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None):
|
||||
"""
|
||||
Helper method to create a test app with realistic data for testing.
|
||||
|
||||
@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService:
|
||||
app.created_by = fake.uuid4()
|
||||
app.updated_by = app.created_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
|
||||
def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None):
|
||||
"""
|
||||
Helper method to create a test workflow associated with an app.
|
||||
|
||||
@ -110,15 +109,14 @@ class TestWorkflowDraftVariableService:
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
return workflow
|
||||
|
||||
def _create_test_variable(
|
||||
self,
|
||||
db_session_with_containers,
|
||||
db_session_with_containers: Session,
|
||||
app_id,
|
||||
node_id,
|
||||
name,
|
||||
@ -174,13 +172,12 @@ class TestWorkflowDraftVariableService:
|
||||
visible=True,
|
||||
editable=True,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(variable)
|
||||
db_session_with_containers.commit()
|
||||
return variable
|
||||
|
||||
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a single variable by ID successfully.
|
||||
|
||||
@ -202,7 +199,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert retrieved_variable.app_id == app.id
|
||||
assert retrieved_variable.get_value().value == test_value.value
|
||||
|
||||
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test getting a variable that doesn't exist.
|
||||
|
||||
@ -217,7 +214,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert retrieved_variable is None
|
||||
|
||||
def test_get_draft_variables_by_selectors_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting variables by selectors successfully.
|
||||
@ -268,7 +265,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert var.get_value().value == var3_value.value
|
||||
|
||||
def test_list_variables_without_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test listing variables without values successfully with pagination.
|
||||
@ -300,7 +297,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert var.name is not None
|
||||
assert var.app_id == app.id
|
||||
|
||||
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test listing variables for a specific node successfully.
|
||||
|
||||
@ -352,7 +349,9 @@ class TestWorkflowDraftVariableService:
|
||||
assert "var2" in var_names
|
||||
assert "var3" not in var_names
|
||||
|
||||
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_list_conversation_variables_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test listing conversation variables successfully.
|
||||
|
||||
@ -393,7 +392,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert "conv_var2" in var_names
|
||||
assert "sys_var" not in var_names
|
||||
|
||||
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test updating a variable's name and value successfully.
|
||||
|
||||
@ -418,14 +417,15 @@ class TestWorkflowDraftVariableService:
|
||||
assert updated_variable.name == "new_name"
|
||||
assert updated_variable.get_value().value == new_value.value
|
||||
assert updated_variable.last_edited_at is not None
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(variable)
|
||||
db_session_with_containers.refresh(variable)
|
||||
assert variable.name == "new_name"
|
||||
assert variable.get_value().value == new_value.value
|
||||
assert variable.last_edited_at is not None
|
||||
|
||||
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_variable_not_editable(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that updating a non-editable variable raises an exception.
|
||||
|
||||
@ -445,17 +445,18 @@ class TestWorkflowDraftVariableService:
|
||||
node_execution_id=fake.uuid4(),
|
||||
editable=False, # Set as non-editable
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(variable)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(variable)
|
||||
db_session_with_containers.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
with pytest.raises(UpdateNotSupportedError) as exc_info:
|
||||
service.update_variable(variable, name="new_name", value=new_value)
|
||||
assert "variable not support updating" in str(exc_info.value)
|
||||
assert variable.id in str(exc_info.value)
|
||||
|
||||
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_reset_conversation_variable_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test resetting conversation variable successfully.
|
||||
|
||||
@ -476,9 +477,8 @@ class TestWorkflowDraftVariableService:
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
modified_value = StringSegment(value=fake.word())
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers,
|
||||
@ -489,17 +489,17 @@ class TestWorkflowDraftVariableService:
|
||||
fake=fake,
|
||||
)
|
||||
variable.last_edited_at = fake.date_time()
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
reset_variable = service.reset_variable(workflow, variable)
|
||||
assert reset_variable is not None
|
||||
assert reset_variable.get_value().value == "default_value"
|
||||
assert reset_variable.last_edited_at is None
|
||||
db.session.refresh(variable)
|
||||
db_session_with_containers.refresh(variable)
|
||||
assert variable.get_value().value == "default_value"
|
||||
assert variable.last_edited_at is None
|
||||
|
||||
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test deleting a single variable successfully.
|
||||
|
||||
@ -513,14 +513,15 @@ class TestWorkflowDraftVariableService:
|
||||
variable = self._create_test_variable(
|
||||
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
|
||||
assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_variable(variable)
|
||||
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
|
||||
assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
|
||||
|
||||
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_workflow_variables_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test deleting all variables for a workflow successfully.
|
||||
|
||||
@ -550,20 +551,25 @@ class TestWorkflowDraftVariableService:
|
||||
other_value,
|
||||
fake=fake,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
)
|
||||
assert len(app_variables) == 3
|
||||
assert len(other_app_variables) == 1
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_workflow_variables(app.id)
|
||||
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
|
||||
other_app_variables_after = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
|
||||
)
|
||||
assert len(app_variables_after) == 0
|
||||
assert len(other_app_variables_after) == 1
|
||||
|
||||
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_node_variables_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test deleting all variables for a specific node successfully.
|
||||
|
||||
@ -605,14 +611,15 @@ class TestWorkflowDraftVariableService:
|
||||
conv_value,
|
||||
fake=fake,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
target_node_variables = (
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
)
|
||||
other_node_variables = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
db_session_with_containers.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
@ -622,13 +629,13 @@ class TestWorkflowDraftVariableService:
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.delete_node_variables(app.id, node_id)
|
||||
target_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
|
||||
)
|
||||
other_node_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
|
||||
)
|
||||
conv_variables_after = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
db_session_with_containers.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
@ -637,7 +644,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert len(conv_variables_after) == 1
|
||||
|
||||
def test_prefill_conversation_variable_default_values_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test prefill conversation variable default values successfully.
|
||||
@ -665,13 +672,12 @@ class TestWorkflowDraftVariableService:
|
||||
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
|
||||
)
|
||||
workflow.conversation_variables = [conv_var1, conv_var2]
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
service = WorkflowDraftVariableService(db_session_with_containers)
|
||||
service.prefill_conversation_variable_default_values(workflow)
|
||||
draft_variables = (
|
||||
db.session.query(WorkflowDraftVariable)
|
||||
db_session_with_containers.query(WorkflowDraftVariable)
|
||||
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
|
||||
.all()
|
||||
)
|
||||
@ -686,7 +692,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert var.get_variable_type() == DraftVariableType.CONVERSATION
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID from draft variable successfully.
|
||||
@ -713,7 +719,7 @@ class TestWorkflowDraftVariableService:
|
||||
assert retrieved_conv_id == conversation_id
|
||||
|
||||
def test_get_conversation_id_from_draft_variable_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting conversation ID when it doesn't exist.
|
||||
@ -728,7 +734,9 @@ class TestWorkflowDraftVariableService:
|
||||
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
|
||||
assert retrieved_conv_id is None
|
||||
|
||||
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_list_system_variables_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test listing system variables successfully.
|
||||
|
||||
@ -775,7 +783,9 @@ class TestWorkflowDraftVariableService:
|
||||
assert "sys_var2" in var_names
|
||||
assert "conv_var" not in var_names
|
||||
|
||||
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_variable_by_name_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting variables by name successfully for different types.
|
||||
|
||||
@ -822,7 +832,9 @@ class TestWorkflowDraftVariableService:
|
||||
assert retrieved_node_var.name == "test_node_var"
|
||||
assert retrieved_node_var.node_id == "test_node"
|
||||
|
||||
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_variable_by_name_not_found(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting variables by name when they don't exist.
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import (
|
||||
@ -48,7 +49,7 @@ class TestWorkflowRunService:
|
||||
"account_feature_service": mock_account_feature_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -94,7 +95,7 @@ class TestWorkflowRunService:
|
||||
return app, account
|
||||
|
||||
def _create_test_workflow_run(
|
||||
self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0
|
||||
self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0
|
||||
):
|
||||
"""
|
||||
Helper method to create a test workflow run for testing.
|
||||
@ -110,8 +111,6 @@ class TestWorkflowRunService:
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create workflow run with offset timestamp
|
||||
base_time = datetime.now(UTC)
|
||||
created_time = base_time - timedelta(minutes=offset_minutes)
|
||||
@ -136,12 +135,12 @@ class TestWorkflowRunService:
|
||||
finished_at=created_time,
|
||||
)
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow_run)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _create_test_message(self, db_session_with_containers, app, account, workflow_run):
|
||||
def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run):
|
||||
"""
|
||||
Helper method to create a test message for testing.
|
||||
|
||||
@ -156,8 +155,6 @@ class TestWorkflowRunService:
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Create conversation first (required for message)
|
||||
from models.model import Conversation
|
||||
|
||||
@ -170,8 +167,8 @@ class TestWorkflowRunService:
|
||||
from_source=CreatorUserRole.ACCOUNT,
|
||||
from_account_id=account.id,
|
||||
)
|
||||
db.session.add(conversation)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(conversation)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create message
|
||||
message = Message()
|
||||
@ -193,12 +190,14 @@ class TestWorkflowRunService:
|
||||
message.workflow_run_id = workflow_run.id
|
||||
message.inputs = {"input": "test input"}
|
||||
|
||||
db.session.add(message)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(message)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return message
|
||||
|
||||
def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_paginate_workflow_runs_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination of workflow runs with debugging trigger.
|
||||
|
||||
@ -239,7 +238,7 @@ class TestWorkflowRunService:
|
||||
assert workflow_run.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_paginate_workflow_runs_with_last_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of workflow runs with last_id parameter.
|
||||
@ -282,7 +281,7 @@ class TestWorkflowRunService:
|
||||
assert workflow_run.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_paginate_workflow_runs_default_limit(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test pagination of workflow runs with default limit.
|
||||
@ -320,7 +319,7 @@ class TestWorkflowRunService:
|
||||
assert workflow_run_result.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_paginate_advanced_chat_workflow_runs_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful pagination of advanced chat workflow runs with message information.
|
||||
@ -365,7 +364,7 @@ class TestWorkflowRunService:
|
||||
assert workflow_run.app_id == app.id
|
||||
assert workflow_run.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of workflow run by ID.
|
||||
|
||||
@ -395,7 +394,7 @@ class TestWorkflowRunService:
|
||||
assert result.type == "chat"
|
||||
assert result.version == "1.0.0"
|
||||
|
||||
def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test workflow run retrieval when run ID does not exist.
|
||||
|
||||
@ -419,7 +418,7 @@ class TestWorkflowRunService:
|
||||
assert result is None
|
||||
|
||||
def test_get_workflow_run_node_executions_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of workflow run node executions.
|
||||
@ -438,7 +437,6 @@ class TestWorkflowRunService:
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Create node executions
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
node_executions = []
|
||||
@ -462,7 +460,7 @@ class TestWorkflowRunService:
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(node_execution)
|
||||
db_session_with_containers.add(node_execution)
|
||||
node_executions.append(node_execution)
|
||||
|
||||
paused_node_execution = WorkflowNodeExecutionModel(
|
||||
@ -484,9 +482,9 @@ class TestWorkflowRunService:
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(paused_node_execution)
|
||||
db_session_with_containers.add(paused_node_execution)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
@ -509,7 +507,7 @@ class TestWorkflowRunService:
|
||||
assert node_execution.node_id.startswith("node_")
|
||||
|
||||
def test_get_workflow_run_node_executions_empty(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting node executions for a workflow run with no executions.
|
||||
@ -560,7 +558,7 @@ class TestWorkflowRunService:
|
||||
assert len(result) == 0
|
||||
|
||||
def test_get_workflow_run_node_executions_invalid_workflow_run_id(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting node executions with invalid workflow run ID.
|
||||
@ -611,7 +609,7 @@ class TestWorkflowRunService:
|
||||
assert len(result) == 0
|
||||
|
||||
def test_get_workflow_run_node_executions_database_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test getting node executions when database encounters an error.
|
||||
@ -662,7 +660,7 @@ class TestWorkflowRunService:
|
||||
)
|
||||
|
||||
def test_get_workflow_run_node_executions_end_user(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test node execution retrieval for end user.
|
||||
@ -680,7 +678,6 @@ class TestWorkflowRunService:
|
||||
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
|
||||
|
||||
# Create end user
|
||||
from extensions.ext_database import db
|
||||
from models.model import EndUser
|
||||
|
||||
end_user = EndUser(
|
||||
@ -692,8 +689,8 @@ class TestWorkflowRunService:
|
||||
external_user_id=str(uuid.uuid4()),
|
||||
name=fake.name(),
|
||||
)
|
||||
db.session.add(end_user)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(end_user)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create node execution
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
@ -717,8 +714,8 @@ class TestWorkflowRunService:
|
||||
created_by=end_user.id,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(node_execution)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(node_execution)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
workflow_run_service = WorkflowRunService()
|
||||
|
||||
@ -10,6 +10,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, App, Workflow
|
||||
from models.model import AppMode
|
||||
@ -32,7 +33,7 @@ class TestWorkflowService:
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
def _create_test_account(self, db_session_with_containers, fake=None):
|
||||
def _create_test_account(self, db_session_with_containers: Session, fake=None):
|
||||
"""
|
||||
Helper method to create a test account with realistic data.
|
||||
|
||||
@ -67,18 +68,16 @@ class TestWorkflowService:
|
||||
tenant.created_at = fake.date_time_this_year()
|
||||
tenant.updated_at = tenant.created_at
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(tenant)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, fake=None):
|
||||
def _create_test_app(self, db_session_with_containers: Session, fake=None):
|
||||
"""
|
||||
Helper method to create a test app with realistic data.
|
||||
|
||||
@ -106,13 +105,11 @@ class TestWorkflowService:
|
||||
)
|
||||
app.updated_by = app.created_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
return app
|
||||
|
||||
def _create_test_workflow(self, db_session_with_containers, app, account, fake=None):
|
||||
def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None):
|
||||
"""
|
||||
Helper method to create a test workflow associated with an app.
|
||||
|
||||
@ -141,13 +138,11 @@ class TestWorkflowService:
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
return workflow
|
||||
|
||||
def test_get_node_last_run_success(self, db_session_with_containers):
|
||||
def test_get_node_last_run_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful retrieval of the most recent execution for a specific node.
|
||||
|
||||
@ -180,10 +175,8 @@ class TestWorkflowService:
|
||||
node_execution.created_by = account.id # Required field
|
||||
node_execution.created_at = fake.date_time_this_year()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(node_execution)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(node_execution)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -196,7 +189,7 @@ class TestWorkflowService:
|
||||
assert result.workflow_id == workflow.id
|
||||
assert result.status == "succeeded"
|
||||
|
||||
def test_get_node_last_run_not_found(self, db_session_with_containers):
|
||||
def test_get_node_last_run_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval when no execution record exists for the specified node.
|
||||
|
||||
@ -217,7 +210,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_is_workflow_exist_true(self, db_session_with_containers):
|
||||
def test_is_workflow_exist_true(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test workflow existence check when a draft workflow exists.
|
||||
|
||||
@ -238,7 +231,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
def test_is_workflow_exist_false(self, db_session_with_containers):
|
||||
def test_is_workflow_exist_false(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test workflow existence check when no draft workflow exists.
|
||||
|
||||
@ -258,7 +251,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
def test_get_draft_workflow_success(self, db_session_with_containers):
|
||||
def test_get_draft_workflow_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful retrieval of a draft workflow.
|
||||
|
||||
@ -284,7 +277,7 @@ class TestWorkflowService:
|
||||
assert result.app_id == app.id
|
||||
assert result.tenant_id == app.tenant_id
|
||||
|
||||
def test_get_draft_workflow_not_found(self, db_session_with_containers):
|
||||
def test_get_draft_workflow_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test draft workflow retrieval when no draft workflow exists.
|
||||
|
||||
@ -304,7 +297,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_published_workflow_by_id_success(self, db_session_with_containers):
|
||||
def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful retrieval of a published workflow by ID.
|
||||
|
||||
@ -321,9 +314,7 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
workflow.version = "2024.01.01.001" # Published version
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -336,7 +327,7 @@ class TestWorkflowService:
|
||||
assert result.version != Workflow.VERSION_DRAFT
|
||||
assert result.app_id == app.id
|
||||
|
||||
def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers):
|
||||
def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when trying to retrieve a draft workflow as published.
|
||||
|
||||
@ -359,7 +350,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(IsDraftWorkflowError):
|
||||
workflow_service.get_published_workflow_by_id(app, workflow.id)
|
||||
|
||||
def test_get_published_workflow_by_id_not_found(self, db_session_with_containers):
|
||||
def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval when no workflow exists with the specified ID.
|
||||
|
||||
@ -379,7 +370,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_published_workflow_success(self, db_session_with_containers):
|
||||
def test_get_published_workflow_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful retrieval of the current published workflow for an app.
|
||||
|
||||
@ -395,10 +386,8 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
workflow.version = "2024.01.01.001" # Published version
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
app.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -411,7 +400,7 @@ class TestWorkflowService:
|
||||
assert result.version != Workflow.VERSION_DRAFT
|
||||
assert result.app_id == app.id
|
||||
|
||||
def test_get_published_workflow_no_workflow_id(self, db_session_with_containers):
|
||||
def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval when app has no associated workflow ID.
|
||||
|
||||
@ -431,7 +420,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_all_published_workflow_pagination(self, db_session_with_containers):
|
||||
def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test pagination of published workflows.
|
||||
|
||||
@ -455,15 +444,13 @@ class TestWorkflowService:
|
||||
# Set the app's workflow_id to the first workflow
|
||||
app.workflow_id = workflows[0].id
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Act - First page
|
||||
result_workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=db.session,
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
page=1,
|
||||
limit=3,
|
||||
@ -476,7 +463,7 @@ class TestWorkflowService:
|
||||
|
||||
# Act - Second page
|
||||
result_workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=db.session,
|
||||
session=db_session_with_containers,
|
||||
app_model=app,
|
||||
page=2,
|
||||
limit=3,
|
||||
@ -487,7 +474,7 @@ class TestWorkflowService:
|
||||
assert len(result_workflows) == 2
|
||||
assert has_more is False
|
||||
|
||||
def test_get_all_published_workflow_user_filter(self, db_session_with_containers):
|
||||
def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test filtering published workflows by user.
|
||||
|
||||
@ -513,22 +500,20 @@ class TestWorkflowService:
|
||||
# Set the app's workflow_id to the first workflow
|
||||
app.workflow_id = workflow1.id
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Act - Filter by account1
|
||||
result_workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=db.session, app_model=app, page=1, limit=10, user_id=account1.id
|
||||
session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result_workflows) == 1
|
||||
assert result_workflows[0].created_by == account1.id
|
||||
|
||||
def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers):
|
||||
def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test filtering published workflows to show only named workflows.
|
||||
|
||||
@ -557,22 +542,20 @@ class TestWorkflowService:
|
||||
# Set the app's workflow_id to the first workflow
|
||||
app.workflow_id = workflow1.id
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Act - Filter named only
|
||||
result_workflows, has_more = workflow_service.get_all_published_workflow(
|
||||
session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True
|
||||
session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result_workflows) == 2
|
||||
assert all(wf.marked_name for wf in result_workflows)
|
||||
|
||||
def test_sync_draft_workflow_create_new(self, db_session_with_containers):
|
||||
def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test creating a new draft workflow through sync operation.
|
||||
|
||||
@ -624,7 +607,7 @@ class TestWorkflowService:
|
||||
assert result.features == json.dumps(features)
|
||||
assert result.created_by == account.id
|
||||
|
||||
def test_sync_draft_workflow_update_existing(self, db_session_with_containers):
|
||||
def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test updating an existing draft workflow through sync operation.
|
||||
|
||||
@ -688,7 +671,7 @@ class TestWorkflowService:
|
||||
assert result.features == json.dumps(new_features)
|
||||
assert result.updated_by == account.id
|
||||
|
||||
def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers):
|
||||
def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when sync is attempted with mismatched hash.
|
||||
|
||||
@ -738,7 +721,7 @@ class TestWorkflowService:
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
def test_publish_workflow_success(self, db_session_with_containers):
|
||||
def test_publish_workflow_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful workflow publishing.
|
||||
|
||||
@ -755,9 +738,7 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
workflow.version = Workflow.VERSION_DRAFT
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -777,7 +758,7 @@ class TestWorkflowService:
|
||||
assert len(result.version) > 10 # Should be a reasonable timestamp length
|
||||
assert result.created_by == account.id
|
||||
|
||||
def test_publish_workflow_no_draft_error(self, db_session_with_containers):
|
||||
def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when publishing workflow without draft.
|
||||
|
||||
@ -797,7 +778,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(ValueError, match="No valid workflow found"):
|
||||
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
|
||||
|
||||
def test_publish_workflow_already_published_error(self, db_session_with_containers):
|
||||
def test_publish_workflow_already_published_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when publishing already published workflow.
|
||||
|
||||
@ -813,9 +794,7 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
workflow.version = "2024.01.01.001" # Already published
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -823,7 +802,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(ValueError, match="No valid workflow found"):
|
||||
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
|
||||
|
||||
def test_get_default_block_configs(self, db_session_with_containers):
|
||||
def test_get_default_block_configs(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval of default block configurations for all node types.
|
||||
|
||||
@ -847,7 +826,7 @@ class TestWorkflowService:
|
||||
assert isinstance(config, dict)
|
||||
# The structure can vary, so we just check it's a dict
|
||||
|
||||
def test_get_default_block_config_specific_type(self, db_session_with_containers):
|
||||
def test_get_default_block_config_specific_type(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval of default block configuration for a specific node type.
|
||||
|
||||
@ -867,7 +846,7 @@ class TestWorkflowService:
|
||||
# This is acceptable behavior
|
||||
assert result is None or isinstance(result, dict)
|
||||
|
||||
def test_get_default_block_config_invalid_type(self, db_session_with_containers):
|
||||
def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval of default block configuration for invalid node type.
|
||||
|
||||
@ -887,7 +866,7 @@ class TestWorkflowService:
|
||||
# It's also acceptable for the service to raise a ValueError for invalid types
|
||||
pass
|
||||
|
||||
def test_get_default_block_config_with_filters(self, db_session_with_containers):
|
||||
def test_get_default_block_config_with_filters(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test retrieval of default block configuration with filters.
|
||||
|
||||
@ -907,7 +886,7 @@ class TestWorkflowService:
|
||||
# Result might be None if filters don't match, but should not raise error
|
||||
assert result is None or isinstance(result, dict)
|
||||
|
||||
def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers):
|
||||
def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful conversion from chat mode app to workflow mode.
|
||||
|
||||
@ -944,11 +923,9 @@ class TestWorkflowService:
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app_model_config)
|
||||
db_session_with_containers.add(app_model_config)
|
||||
app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
conversion_args = {
|
||||
@ -969,7 +946,7 @@ class TestWorkflowService:
|
||||
assert result.icon_type == conversion_args["icon_type"]
|
||||
assert result.icon_background == conversion_args["icon_background"]
|
||||
|
||||
def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers):
|
||||
def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful conversion from completion mode app to workflow mode.
|
||||
|
||||
@ -1006,11 +983,9 @@ class TestWorkflowService:
|
||||
)
|
||||
app_model_config.id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app_model_config)
|
||||
db_session_with_containers.add(app_model_config)
|
||||
app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
conversion_args = {
|
||||
@ -1031,7 +1006,7 @@ class TestWorkflowService:
|
||||
assert result.icon_type == conversion_args["icon_type"]
|
||||
assert result.icon_background == conversion_args["icon_background"]
|
||||
|
||||
def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers):
|
||||
def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when attempting to convert unsupported app mode.
|
||||
|
||||
@ -1046,9 +1021,7 @@ class TestWorkflowService:
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
conversion_args = {"name": "Test"}
|
||||
@ -1057,7 +1030,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"):
|
||||
workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args)
|
||||
|
||||
def test_validate_features_structure_advanced_chat(self, db_session_with_containers):
|
||||
def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test feature structure validation for advanced chat mode apps.
|
||||
|
||||
@ -1069,9 +1042,7 @@ class TestWorkflowService:
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
features = {
|
||||
@ -1088,7 +1059,7 @@ class TestWorkflowService:
|
||||
# The exact behavior depends on the AdvancedChatAppConfigManager implementation
|
||||
assert result is not None or isinstance(result, dict)
|
||||
|
||||
def test_validate_features_structure_workflow(self, db_session_with_containers):
|
||||
def test_validate_features_structure_workflow(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test feature structure validation for workflow mode apps.
|
||||
|
||||
@ -1100,9 +1071,7 @@ class TestWorkflowService:
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
features = {"workflow_config": {"max_steps": 10, "timeout": 300}}
|
||||
@ -1115,7 +1084,7 @@ class TestWorkflowService:
|
||||
# The exact behavior depends on the WorkflowAppConfigManager implementation
|
||||
assert result is not None or isinstance(result, dict)
|
||||
|
||||
def test_validate_features_structure_invalid_mode(self, db_session_with_containers):
|
||||
def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when validating features for invalid app mode.
|
||||
|
||||
@ -1127,9 +1096,7 @@ class TestWorkflowService:
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
app.mode = "invalid_mode" # Invalid mode
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
features = {"test": "value"}
|
||||
@ -1138,7 +1105,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"):
|
||||
workflow_service.validate_features_structure(app_model=app, features=features)
|
||||
|
||||
def test_update_workflow_success(self, db_session_with_containers):
|
||||
def test_update_workflow_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful workflow update with allowed fields.
|
||||
|
||||
@ -1152,16 +1119,14 @@ class TestWorkflowService:
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"}
|
||||
|
||||
# Act
|
||||
result = workflow_service.update_workflow(
|
||||
session=db.session,
|
||||
session=db_session_with_containers,
|
||||
workflow_id=workflow.id,
|
||||
tenant_id=workflow.tenant_id,
|
||||
account_id=account.id,
|
||||
@ -1174,7 +1139,7 @@ class TestWorkflowService:
|
||||
assert result.marked_comment == update_data["marked_comment"]
|
||||
assert result.updated_by == account.id
|
||||
|
||||
def test_update_workflow_not_found(self, db_session_with_containers):
|
||||
def test_update_workflow_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test workflow update when workflow doesn't exist.
|
||||
|
||||
@ -1186,15 +1151,13 @@ class TestWorkflowService:
|
||||
account = self._create_test_account(db_session_with_containers, fake)
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
non_existent_workflow_id = fake.uuid4()
|
||||
update_data = {"marked_name": "Test"}
|
||||
|
||||
# Act
|
||||
result = workflow_service.update_workflow(
|
||||
session=db.session,
|
||||
session=db_session_with_containers,
|
||||
workflow_id=non_existent_workflow_id,
|
||||
tenant_id=app.tenant_id,
|
||||
account_id=account.id,
|
||||
@ -1204,7 +1167,7 @@ class TestWorkflowService:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers):
|
||||
def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that workflow update ignores disallowed fields.
|
||||
|
||||
@ -1218,9 +1181,7 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
original_name = workflow.marked_name
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
update_data = {
|
||||
@ -1231,7 +1192,7 @@ class TestWorkflowService:
|
||||
|
||||
# Act
|
||||
result = workflow_service.update_workflow(
|
||||
session=db.session,
|
||||
session=db_session_with_containers,
|
||||
workflow_id=workflow.id,
|
||||
tenant_id=workflow.tenant_id,
|
||||
account_id=account.id,
|
||||
@ -1245,7 +1206,7 @@ class TestWorkflowService:
|
||||
assert result.graph == workflow.graph
|
||||
assert result.features == workflow.features
|
||||
|
||||
def test_delete_workflow_success(self, db_session_with_containers):
|
||||
def test_delete_workflow_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful workflow deletion.
|
||||
|
||||
@ -1262,25 +1223,23 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
workflow.version = "2024.01.01.001" # Published version
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Act
|
||||
result = workflow_service.delete_workflow(
|
||||
session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id
|
||||
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
# Verify workflow is actually deleted
|
||||
deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first()
|
||||
deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first()
|
||||
assert deleted_workflow is None
|
||||
|
||||
def test_delete_workflow_draft_error(self, db_session_with_containers):
|
||||
def test_delete_workflow_draft_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when attempting to delete a draft workflow.
|
||||
|
||||
@ -1296,9 +1255,7 @@ class TestWorkflowService:
|
||||
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
|
||||
# Keep as draft version
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -1306,9 +1263,11 @@ class TestWorkflowService:
|
||||
from services.errors.workflow_service import DraftWorkflowDeletionError
|
||||
|
||||
with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"):
|
||||
workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id)
|
||||
workflow_service.delete_workflow(
|
||||
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
|
||||
)
|
||||
|
||||
def test_delete_workflow_in_use_error(self, db_session_with_containers):
|
||||
def test_delete_workflow_in_use_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when attempting to delete a workflow that's in use by an app.
|
||||
|
||||
@ -1327,9 +1286,7 @@ class TestWorkflowService:
|
||||
# Associate workflow with app
|
||||
app.workflow_id = workflow.id
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
@ -1337,9 +1294,11 @@ class TestWorkflowService:
|
||||
from services.errors.workflow_service import WorkflowInUseError
|
||||
|
||||
with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"):
|
||||
workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id)
|
||||
workflow_service.delete_workflow(
|
||||
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
|
||||
)
|
||||
|
||||
def test_delete_workflow_not_found_error(self, db_session_with_containers):
|
||||
def test_delete_workflow_not_found_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test error when attempting to delete a non-existent workflow.
|
||||
|
||||
@ -1351,17 +1310,15 @@ class TestWorkflowService:
|
||||
app = self._create_test_app(db_session_with_containers, fake)
|
||||
non_existent_workflow_id = fake.uuid4()
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"):
|
||||
workflow_service.delete_workflow(
|
||||
session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id
|
||||
session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id
|
||||
)
|
||||
|
||||
def test_run_free_workflow_node_success(self, db_session_with_containers):
|
||||
def test_run_free_workflow_node_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful execution of a free workflow node.
|
||||
|
||||
@ -1413,7 +1370,7 @@ class TestWorkflowService:
|
||||
assert result.workflow_id == "" # No workflow ID for free nodes
|
||||
assert result.index == 1
|
||||
|
||||
def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers):
|
||||
def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test execution of a free workflow node with complex input data.
|
||||
|
||||
@ -1454,7 +1411,7 @@ class TestWorkflowService:
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert any(keyword in error_msg for keyword in ["start", "not supported", "external"])
|
||||
|
||||
def test_handle_node_run_result_success(self, db_session_with_containers):
|
||||
def test_handle_node_run_result_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful handling of node run results.
|
||||
|
||||
@ -1529,7 +1486,7 @@ class TestWorkflowService:
|
||||
assert result.outputs is not None
|
||||
assert result.process_data is not None
|
||||
|
||||
def test_handle_node_run_result_failure(self, db_session_with_containers):
|
||||
def test_handle_node_run_result_failure(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling of failed node run results.
|
||||
|
||||
@ -1598,7 +1555,7 @@ class TestWorkflowService:
|
||||
assert result.error is not None
|
||||
assert "Test error message" in str(result.error)
|
||||
|
||||
def test_handle_node_run_result_continue_on_error(self, db_session_with_containers):
|
||||
def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling of node run results with continue_on_error strategy.
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from services.workspace_service import WorkspaceService
|
||||
@ -29,7 +30,7 @@ class TestWorkspaceService:
|
||||
"dify_config": mock_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -50,10 +51,8 @@ class TestWorkspaceService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
@ -62,8 +61,8 @@ class TestWorkspaceService:
|
||||
plan="basic",
|
||||
custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join with owner role
|
||||
join = TenantAccountJoin(
|
||||
@ -72,15 +71,15 @@ class TestWorkspaceService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of tenant information with all features enabled.
|
||||
|
||||
@ -121,13 +120,12 @@ class TestWorkspaceService:
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_without_custom_config(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval when custom config features are disabled.
|
||||
@ -167,13 +165,12 @@ class TestWorkspaceService:
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_normal_user_role(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for normal user role without privileged features.
|
||||
@ -191,11 +188,14 @@ class TestWorkspaceService:
|
||||
)
|
||||
|
||||
# Update the join to have normal role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant.id, account_id=account.id)
|
||||
.first()
|
||||
)
|
||||
join.role = TenantAccountRole.NORMAL
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks for feature service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
@ -220,11 +220,11 @@ class TestWorkspaceService:
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_admin_role_and_logo_replacement(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for admin role with logo replacement enabled.
|
||||
@ -242,11 +242,14 @@ class TestWorkspaceService:
|
||||
)
|
||||
|
||||
# Update the join to have admin role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant.id, account_id=account.id)
|
||||
.first()
|
||||
)
|
||||
join.role = TenantAccountRole.ADMIN
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
@ -268,10 +271,12 @@ class TestWorkspaceService:
|
||||
assert "replace_webapp_logo" in result["custom_config"]
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_tenant_info_with_tenant_none(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval when tenant parameter is None.
|
||||
|
||||
@ -290,7 +295,7 @@ class TestWorkspaceService:
|
||||
assert result is None
|
||||
|
||||
def test_get_tenant_info_with_custom_config_variations(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval with various custom config configurations.
|
||||
@ -323,10 +328,8 @@ class TestWorkspaceService:
|
||||
# Update tenant custom config
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tenant.custom_config = json.dumps(config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
@ -353,11 +356,11 @@ class TestWorkspaceService:
|
||||
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_editor_role_and_limited_permissions(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for editor role with limited permissions.
|
||||
@ -375,11 +378,14 @@ class TestWorkspaceService:
|
||||
)
|
||||
|
||||
# Update the join to have editor role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant.id, account_id=account.id)
|
||||
.first()
|
||||
)
|
||||
join.role = TenantAccountRole.EDITOR
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
@ -400,11 +406,11 @@ class TestWorkspaceService:
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_dataset_operator_role(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval for dataset operator role.
|
||||
@ -422,11 +428,14 @@ class TestWorkspaceService:
|
||||
)
|
||||
|
||||
# Update the join to have dataset operator role
|
||||
from extensions.ext_database import db
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
join = (
|
||||
db_session_with_containers.query(TenantAccountJoin)
|
||||
.filter_by(tenant_id=tenant.id, account_id=account.id)
|
||||
.first()
|
||||
)
|
||||
join.role = TenantAccountRole.DATASET_OPERATOR
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks for feature service and tenant service
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
@ -447,11 +456,11 @@ class TestWorkspaceService:
|
||||
assert "custom_config" not in result
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_with_complex_custom_config_scenarios(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant info retrieval with complex custom config scenarios.
|
||||
@ -491,10 +500,8 @@ class TestWorkspaceService:
|
||||
# Update tenant custom config
|
||||
import json
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tenant.custom_config = json.dumps(config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mocks
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
@ -525,5 +532,5 @@ class TestWorkspaceService:
|
||||
assert result["custom_config"]["remove_webapp_brand"] is False
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(tenant)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ApiProviderSchemaType
|
||||
from models import Account, Tenant
|
||||
@ -34,7 +35,7 @@ class TestApiToolManageService:
|
||||
"provider_controller": mock_provider_controller,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -55,18 +56,16 @@ class TestApiToolManageService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
@ -77,8 +76,8 @@ class TestApiToolManageService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
@ -118,7 +117,7 @@ class TestApiToolManageService:
|
||||
"""
|
||||
|
||||
def test_parser_api_schema_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful parsing of API schema.
|
||||
@ -163,7 +162,7 @@ class TestApiToolManageService:
|
||||
assert api_key_value_field["default"] == ""
|
||||
|
||||
def test_parser_api_schema_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parsing of invalid API schema.
|
||||
@ -183,7 +182,7 @@ class TestApiToolManageService:
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_parser_api_schema_malformed_json(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test parsing of malformed JSON schema.
|
||||
@ -203,7 +202,7 @@ class TestApiToolManageService:
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles.
|
||||
@ -233,7 +232,7 @@ class TestApiToolManageService:
|
||||
assert tool_bundle.operation_id == "testOperation"
|
||||
|
||||
def test_convert_schema_to_tool_bundles_with_extra_info(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of schema to tool bundles with extra info.
|
||||
@ -259,7 +258,7 @@ class TestApiToolManageService:
|
||||
assert isinstance(schema_type, str)
|
||||
|
||||
def test_convert_schema_to_tool_bundles_invalid_schema(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of invalid schema to tool bundles.
|
||||
@ -279,7 +278,7 @@ class TestApiToolManageService:
|
||||
assert "invalid schema" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_success(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful creation of API tool provider.
|
||||
@ -324,10 +323,9 @@ class TestApiToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
@ -347,7 +345,7 @@ class TestApiToolManageService:
|
||||
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
|
||||
|
||||
def test_create_api_tool_provider_duplicate_name(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with duplicate name.
|
||||
@ -404,7 +402,7 @@ class TestApiToolManageService:
|
||||
assert f"provider {provider_name} already exists" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_invalid_schema_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with invalid schema type.
|
||||
@ -436,7 +434,7 @@ class TestApiToolManageService:
|
||||
assert "validation error" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_missing_auth_type(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test creation of API tool provider with missing auth type.
|
||||
@ -479,7 +477,7 @@ class TestApiToolManageService:
|
||||
assert "auth_type is required" in str(exc_info.value)
|
||||
|
||||
def test_create_api_tool_provider_with_api_key_auth(
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
|
||||
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful creation of API tool provider with API key authentication.
|
||||
@ -522,10 +520,9 @@ class TestApiToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
provider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
db_session_with_containers.query(ApiToolProvider)
|
||||
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models import Account, Tenant
|
||||
@ -41,7 +42,7 @@ class TestMCPToolManageService:
|
||||
"tool_transform_service": mock_tool_transform_service,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -62,18 +63,16 @@ class TestMCPToolManageService:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
@ -84,8 +83,8 @@ class TestMCPToolManageService:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
@ -93,7 +92,7 @@ class TestMCPToolManageService:
|
||||
return account, tenant
|
||||
|
||||
def _create_test_mcp_provider(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id
|
||||
):
|
||||
"""
|
||||
Helper method to create a test MCP tool provider for testing.
|
||||
@ -124,15 +123,13 @@ class TestMCPToolManageService:
|
||||
sse_read_timeout=300.0,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(mcp_provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(mcp_provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return mcp_provider
|
||||
|
||||
def test_get_mcp_provider_by_provider_id_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of MCP provider by provider ID.
|
||||
@ -153,9 +150,8 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -166,12 +162,12 @@ class TestMCPToolManageService:
|
||||
assert result.user_id == account.id
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.server_identifier == mcp_provider.server_identifier
|
||||
|
||||
def test_get_mcp_provider_by_provider_id_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when MCP provider is not found by provider ID.
|
||||
@ -190,14 +186,13 @@ class TestMCPToolManageService:
|
||||
non_existent_id = str(fake.uuid4())
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
|
||||
|
||||
def test_get_mcp_provider_by_provider_id_tenant_isolation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant isolation when retrieving MCP provider by provider ID.
|
||||
@ -223,14 +218,13 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Act & Assert: Verify tenant isolation
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
|
||||
|
||||
def test_get_mcp_provider_by_server_identifier_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful retrieval of MCP provider by server identifier.
|
||||
@ -251,9 +245,8 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -264,12 +257,12 @@ class TestMCPToolManageService:
|
||||
assert result.user_id == account.id
|
||||
|
||||
# Verify database state
|
||||
db.session.refresh(result)
|
||||
db_session_with_containers.refresh(result)
|
||||
assert result.id is not None
|
||||
assert result.name == mcp_provider.name
|
||||
|
||||
def test_get_mcp_provider_by_server_identifier_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when MCP provider is not found by server identifier.
|
||||
@ -288,14 +281,13 @@ class TestMCPToolManageService:
|
||||
non_existent_identifier = str(fake.uuid4())
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
|
||||
|
||||
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant isolation when retrieving MCP provider by server identifier.
|
||||
@ -321,13 +313,12 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Act & Assert: Verify tenant isolation
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
|
||||
|
||||
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful creation of MCP provider.
|
||||
|
||||
@ -365,9 +356,8 @@ class TestMCPToolManageService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant.id,
|
||||
name="Test MCP Provider",
|
||||
@ -389,10 +379,9 @@ class TestMCPToolManageService:
|
||||
assert result.type == ToolProviderType.MCP
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_provider = (
|
||||
db.session.query(MCPToolProvider)
|
||||
db_session_with_containers.query(MCPToolProvider)
|
||||
.filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider")
|
||||
.first()
|
||||
)
|
||||
@ -410,7 +399,9 @@ class TestMCPToolManageService:
|
||||
)
|
||||
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once()
|
||||
|
||||
def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_mcp_provider_duplicate_name(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when creating MCP provider with duplicate name.
|
||||
|
||||
@ -427,9 +418,8 @@ class TestMCPToolManageService:
|
||||
|
||||
# Create first provider
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.create_provider(
|
||||
tenant_id=tenant.id,
|
||||
name="Test MCP Provider",
|
||||
@ -463,7 +453,7 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
def test_create_mcp_provider_duplicate_server_url(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when creating MCP provider with duplicate server URL.
|
||||
@ -481,9 +471,8 @@ class TestMCPToolManageService:
|
||||
|
||||
# Create first provider
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.create_provider(
|
||||
tenant_id=tenant.id,
|
||||
name="Test MCP Provider 1",
|
||||
@ -517,7 +506,7 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
def test_create_mcp_provider_duplicate_server_identifier(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when creating MCP provider with duplicate server identifier.
|
||||
@ -535,9 +524,8 @@ class TestMCPToolManageService:
|
||||
|
||||
# Create first provider
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.create_provider(
|
||||
tenant_id=tenant.id,
|
||||
name="Test MCP Provider 1",
|
||||
@ -570,7 +558,7 @@ class TestMCPToolManageService:
|
||||
),
|
||||
)
|
||||
|
||||
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful retrieval of MCP tools for a tenant.
|
||||
|
||||
@ -602,9 +590,7 @@ class TestMCPToolManageService:
|
||||
)
|
||||
provider3.name = "Gamma Provider"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Setup mock for transformation service
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
@ -647,9 +633,8 @@ class TestMCPToolManageService:
|
||||
]
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result = service.list_providers(tenant_id=tenant.id, for_list=True)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -666,7 +651,9 @@ class TestMCPToolManageService:
|
||||
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3
|
||||
)
|
||||
|
||||
def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_retrieve_mcp_tools_empty_list(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test retrieval of MCP tools when tenant has no providers.
|
||||
|
||||
@ -684,9 +671,8 @@ class TestMCPToolManageService:
|
||||
# No MCP providers created for this tenant
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result = service.list_providers(tenant_id=tenant.id, for_list=False)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -697,7 +683,9 @@ class TestMCPToolManageService:
|
||||
# Verify no transformation service calls for empty list
|
||||
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called()
|
||||
|
||||
def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_retrieve_mcp_tools_tenant_isolation(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant isolation when retrieving MCP tools.
|
||||
|
||||
@ -756,9 +744,8 @@ class TestMCPToolManageService:
|
||||
]
|
||||
|
||||
# Act: Execute the method under test for both tenants
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
|
||||
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
|
||||
|
||||
@ -769,7 +756,7 @@ class TestMCPToolManageService:
|
||||
assert result2[0].id == provider2.id
|
||||
|
||||
def test_list_mcp_tool_from_remote_server_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful listing of MCP tools from remote server.
|
||||
@ -797,9 +784,7 @@ class TestMCPToolManageService:
|
||||
mcp_provider.authed = True # Provider must be authenticated to list tools
|
||||
mcp_provider.tools = "[]"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the decryption process at the rsa level to avoid key file issues
|
||||
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||
@ -821,9 +806,8 @@ class TestMCPToolManageService:
|
||||
mock_client_instance.list_tools.return_value = mock_tools
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -834,7 +818,7 @@ class TestMCPToolManageService:
|
||||
# Note: server_url is mocked, so we skip that assertion to avoid encryption issues
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(mcp_provider)
|
||||
db_session_with_containers.refresh(mcp_provider)
|
||||
assert mcp_provider.authed is True
|
||||
assert mcp_provider.tools != "[]"
|
||||
assert mcp_provider.updated_at is not None
|
||||
@ -844,7 +828,7 @@ class TestMCPToolManageService:
|
||||
mock_mcp_client.assert_called_once()
|
||||
|
||||
def test_list_mcp_tool_from_remote_server_auth_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when MCP server requires authentication.
|
||||
@ -871,9 +855,7 @@ class TestMCPToolManageService:
|
||||
mcp_provider.authed = False
|
||||
mcp_provider.tools = "[]"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the decryption process at the rsa level to avoid key file issues
|
||||
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||
@ -887,19 +869,18 @@ class TestMCPToolManageService:
|
||||
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Please auth the tool first"):
|
||||
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||
|
||||
# Verify database state was not changed
|
||||
db.session.refresh(mcp_provider)
|
||||
db_session_with_containers.refresh(mcp_provider)
|
||||
assert mcp_provider.authed is False
|
||||
assert mcp_provider.tools == "[]"
|
||||
|
||||
def test_list_mcp_tool_from_remote_server_connection_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when MCP server connection fails.
|
||||
@ -926,9 +907,7 @@ class TestMCPToolManageService:
|
||||
mcp_provider.authed = True # Provider must be authenticated to test connection errors
|
||||
mcp_provider.tools = "[]"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the decryption process at the rsa level to avoid key file issues
|
||||
with patch("libs.rsa.decrypt") as mock_decrypt:
|
||||
@ -942,18 +921,17 @@ class TestMCPToolManageService:
|
||||
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
|
||||
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||
|
||||
# Verify database state was not changed
|
||||
db.session.refresh(mcp_provider)
|
||||
db_session_with_containers.refresh(mcp_provider)
|
||||
assert mcp_provider.authed is True # Provider remains authenticated
|
||||
assert mcp_provider.tools == "[]"
|
||||
|
||||
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful deletion of MCP tool.
|
||||
|
||||
@ -974,20 +952,19 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Verify provider exists
|
||||
from extensions.ext_database import db
|
||||
|
||||
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
|
||||
assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
# Provider should be deleted from database
|
||||
deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
|
||||
deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
|
||||
assert deleted_provider is None
|
||||
|
||||
def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test error handling when deleting non-existent MCP tool.
|
||||
|
||||
@ -1005,13 +982,14 @@ class TestMCPToolManageService:
|
||||
non_existent_id = str(fake.uuid4())
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
|
||||
|
||||
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_delete_mcp_tool_tenant_isolation(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tenant isolation when deleting MCP tool.
|
||||
|
||||
@ -1036,18 +1014,16 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Act & Assert: Verify tenant isolation
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool not found"):
|
||||
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
|
||||
|
||||
# Verify provider still exists in tenant1
|
||||
from extensions.ext_database import db
|
||||
|
||||
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
|
||||
assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
|
||||
|
||||
def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful update of MCP provider.
|
||||
|
||||
@ -1070,14 +1046,12 @@ class TestMCPToolManageService:
|
||||
original_name = mcp_provider.name
|
||||
original_icon = mcp_provider.icon
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.update_provider(
|
||||
tenant_id=tenant.id,
|
||||
provider_id=mcp_provider.id,
|
||||
@ -1094,7 +1068,7 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
db.session.refresh(mcp_provider)
|
||||
db_session_with_containers.refresh(mcp_provider)
|
||||
assert mcp_provider.name == "Updated MCP Provider"
|
||||
assert mcp_provider.server_identifier == "updated_identifier_123"
|
||||
assert mcp_provider.timeout == 45.0
|
||||
@ -1108,7 +1082,9 @@ class TestMCPToolManageService:
|
||||
assert icon_data["content"] == "🚀"
|
||||
assert icon_data["background"] == "#4ECDC4"
|
||||
|
||||
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_mcp_provider_duplicate_name(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when updating MCP provider with duplicate name.
|
||||
|
||||
@ -1134,15 +1110,12 @@ class TestMCPToolManageService:
|
||||
)
|
||||
provider2.name = "Second Provider"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling for duplicate name
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
|
||||
service.update_provider(
|
||||
tenant_id=tenant.id,
|
||||
@ -1160,7 +1133,7 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
def test_update_mcp_provider_credentials_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful update of MCP provider credentials.
|
||||
@ -1185,9 +1158,7 @@ class TestMCPToolManageService:
|
||||
mcp_provider.authed = False
|
||||
mcp_provider.tools = "[]"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the provider controller and encryption
|
||||
with (
|
||||
@ -1202,9 +1173,8 @@ class TestMCPToolManageService:
|
||||
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.update_provider_credentials(
|
||||
provider_id=mcp_provider.id,
|
||||
tenant_id=tenant.id,
|
||||
@ -1213,7 +1183,7 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
db.session.refresh(mcp_provider)
|
||||
db_session_with_containers.refresh(mcp_provider)
|
||||
assert mcp_provider.authed is True
|
||||
assert mcp_provider.updated_at is not None
|
||||
|
||||
@ -1225,7 +1195,7 @@ class TestMCPToolManageService:
|
||||
assert "new_key" in credentials
|
||||
|
||||
def test_update_mcp_provider_credentials_not_authed(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test update of MCP provider credentials when not authenticated.
|
||||
@ -1249,9 +1219,7 @@ class TestMCPToolManageService:
|
||||
mcp_provider.authed = True
|
||||
mcp_provider.tools = '[{"name": "test_tool"}]'
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the provider controller and encryption
|
||||
with (
|
||||
@ -1266,9 +1234,8 @@ class TestMCPToolManageService:
|
||||
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
|
||||
|
||||
# Act: Execute the method under test
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service = MCPToolManageService(db_session_with_containers)
|
||||
service.update_provider_credentials(
|
||||
provider_id=mcp_provider.id,
|
||||
tenant_id=tenant.id,
|
||||
@ -1277,12 +1244,14 @@ class TestMCPToolManageService:
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
db.session.refresh(mcp_provider)
|
||||
db_session_with_containers.refresh(mcp_provider)
|
||||
assert mcp_provider.authed is False
|
||||
assert mcp_provider.tools == "[]"
|
||||
assert mcp_provider.updated_at is not None
|
||||
|
||||
def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_re_connect_mcp_provider_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful reconnection to MCP provider.
|
||||
|
||||
@ -1343,7 +1312,9 @@ class TestMCPToolManageService:
|
||||
sse_read_timeout=mcp_provider.sse_read_timeout,
|
||||
)
|
||||
|
||||
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_re_connect_mcp_provider_auth_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test reconnection to MCP provider when authentication fails.
|
||||
|
||||
@ -1385,7 +1356,7 @@ class TestMCPToolManageService:
|
||||
assert result.encrypted_credentials == "{}"
|
||||
|
||||
def test_re_connect_mcp_provider_connection_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test reconnection to MCP provider when connection fails.
|
||||
|
||||
@ -2,6 +2,7 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@ -27,7 +28,7 @@ class TestToolTransformService:
|
||||
}
|
||||
|
||||
def _create_test_tool_provider(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api"
|
||||
):
|
||||
"""
|
||||
Helper method to create a test tool provider for testing.
|
||||
@ -89,14 +90,12 @@ class TestToolTransformService:
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return provider
|
||||
|
||||
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful plugin icon URL generation.
|
||||
|
||||
@ -126,7 +125,7 @@ class TestToolTransformService:
|
||||
assert result == expected_url
|
||||
|
||||
def test_get_plugin_icon_url_with_empty_console_url(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test plugin icon URL generation when CONSOLE_API_URL is empty.
|
||||
@ -156,7 +155,7 @@ class TestToolTransformService:
|
||||
assert result == expected_url
|
||||
|
||||
def test_get_tool_provider_icon_url_builtin_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for builtin providers.
|
||||
@ -194,7 +193,7 @@ class TestToolTransformService:
|
||||
assert result == expected_encoded
|
||||
|
||||
def test_get_tool_provider_icon_url_api_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for API providers.
|
||||
@ -220,7 +219,7 @@ class TestToolTransformService:
|
||||
assert result["content"] == "🔧"
|
||||
|
||||
def test_get_tool_provider_icon_url_api_invalid_json(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tool provider icon URL generation for API providers with invalid JSON.
|
||||
@ -246,7 +245,7 @@ class TestToolTransformService:
|
||||
assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"
|
||||
|
||||
def test_get_tool_provider_icon_url_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for workflow providers.
|
||||
@ -271,7 +270,7 @@ class TestToolTransformService:
|
||||
assert result["content"] == "🔧"
|
||||
|
||||
def test_get_tool_provider_icon_url_mcp_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful tool provider icon URL generation for MCP providers.
|
||||
@ -296,7 +295,7 @@ class TestToolTransformService:
|
||||
assert result["content"] == "🔧"
|
||||
|
||||
def test_get_tool_provider_icon_url_unknown_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test tool provider icon URL generation for unknown provider types.
|
||||
@ -317,7 +316,9 @@ class TestToolTransformService:
|
||||
# Assert: Verify the expected outcomes
|
||||
assert result == ""
|
||||
|
||||
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_repack_provider_dict_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful provider repacking with dictionary input.
|
||||
|
||||
@ -341,7 +342,9 @@ class TestToolTransformService:
|
||||
# Note: provider name may contain spaces that get URL encoded
|
||||
assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
|
||||
|
||||
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_repack_provider_entity_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful provider repacking with ToolProviderApiEntity input.
|
||||
|
||||
@ -389,7 +392,7 @@ class TestToolTransformService:
|
||||
assert "test_icon_dark.png" in provider.icon_dark
|
||||
|
||||
def test_repack_provider_entity_no_plugin_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
|
||||
@ -435,7 +438,9 @@ class TestToolTransformService:
|
||||
assert provider.icon_dark["background"] == "#252525"
|
||||
assert provider.icon_dark["content"] == "🔧"
|
||||
|
||||
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_repack_provider_entity_no_dark_icon(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test provider repacking with ToolProviderApiEntity input without dark icon.
|
||||
|
||||
@ -477,7 +482,7 @@ class TestToolTransformService:
|
||||
assert provider.icon_dark == ""
|
||||
|
||||
def test_builtin_provider_to_user_provider_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of builtin provider to user provider.
|
||||
@ -545,7 +550,7 @@ class TestToolTransformService:
|
||||
assert result.original_credentials == {"api_key": "decrypted_key"}
|
||||
|
||||
def test_builtin_provider_to_user_provider_plugin_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of builtin provider to user provider with plugin.
|
||||
@ -589,7 +594,7 @@ class TestToolTransformService:
|
||||
assert result.allow_delete is False
|
||||
|
||||
def test_builtin_provider_to_user_provider_no_credentials(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of builtin provider to user provider without credentials.
|
||||
@ -630,7 +635,9 @@ class TestToolTransformService:
|
||||
assert result.allow_delete is False
|
||||
assert result.masked_credentials == {"api_key": ""}
|
||||
|
||||
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_api_provider_to_controller_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of API provider to controller.
|
||||
|
||||
@ -655,10 +662,8 @@ class TestToolTransformService:
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.api_provider_to_controller(provider)
|
||||
@ -669,7 +674,7 @@ class TestToolTransformService:
|
||||
# Additional assertions would depend on the actual controller implementation
|
||||
|
||||
def test_api_provider_to_controller_api_key_query(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of API provider to controller with api_key_query auth type.
|
||||
@ -693,10 +698,8 @@ class TestToolTransformService:
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.api_provider_to_controller(provider)
|
||||
@ -706,7 +709,7 @@ class TestToolTransformService:
|
||||
assert hasattr(result, "from_db")
|
||||
|
||||
def test_api_provider_to_controller_backward_compatibility(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test conversion of API provider to controller with backward compatibility auth types.
|
||||
@ -731,10 +734,8 @@ class TestToolTransformService:
|
||||
tools_str="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
result = ToolTransformService.api_provider_to_controller(provider)
|
||||
@ -744,7 +745,7 @@ class TestToolTransformService:
|
||||
assert hasattr(result, "from_db")
|
||||
|
||||
def test_workflow_provider_to_controller_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of workflow provider to controller.
|
||||
@ -769,10 +770,8 @@ class TestToolTransformService:
|
||||
parameter_configuration="[]",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(provider)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the WorkflowToolProviderController.from_db method to avoid app dependency
|
||||
with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:
|
||||
|
||||
@ -4,6 +4,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
@ -63,7 +64,7 @@ class TestWorkflowToolManageService:
|
||||
"tool_transform_service": mock_tool_transform_service,
|
||||
}
|
||||
|
||||
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test app and account for testing.
|
||||
|
||||
@ -119,14 +120,12 @@ class TestWorkflowToolManageService:
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Update app to reference the workflow
|
||||
app.workflow_id = workflow.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app, account, workflow
|
||||
|
||||
@ -153,7 +152,9 @@ class TestWorkflowToolManageService:
|
||||
),
|
||||
]
|
||||
|
||||
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_create_workflow_tool_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful workflow tool creation with valid parameters.
|
||||
|
||||
@ -198,11 +199,10 @@ class TestWorkflowToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check if workflow tool provider was created
|
||||
created_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
@ -230,7 +230,7 @@ class TestWorkflowToolManageService:
|
||||
].workflow_provider_to_controller.assert_called_once()
|
||||
|
||||
def test_create_workflow_tool_duplicate_name_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when name already exists.
|
||||
@ -280,10 +280,9 @@ class TestWorkflowToolManageService:
|
||||
assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
|
||||
|
||||
# Verify only one tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -293,7 +292,7 @@ class TestWorkflowToolManageService:
|
||||
assert tool_count == 1
|
||||
|
||||
def test_create_workflow_tool_invalid_app_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when app does not exist.
|
||||
@ -331,10 +330,9 @@ class TestWorkflowToolManageService:
|
||||
assert f"App {non_existent_app_id} not found" in str(exc_info.value)
|
||||
|
||||
# Verify no workflow tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -344,7 +342,7 @@ class TestWorkflowToolManageService:
|
||||
assert tool_count == 0
|
||||
|
||||
def test_create_workflow_tool_invalid_parameters_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when parameters are invalid.
|
||||
@ -387,10 +385,9 @@ class TestWorkflowToolManageService:
|
||||
assert "validation error" in str(exc_info.value).lower()
|
||||
|
||||
# Verify no workflow tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -400,7 +397,7 @@ class TestWorkflowToolManageService:
|
||||
assert tool_count == 0
|
||||
|
||||
def test_create_workflow_tool_duplicate_app_id_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when app_id already exists.
|
||||
@ -450,10 +447,9 @@ class TestWorkflowToolManageService:
|
||||
assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
|
||||
|
||||
# Verify only one tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -463,7 +459,7 @@ class TestWorkflowToolManageService:
|
||||
assert tool_count == 1
|
||||
|
||||
def test_create_workflow_tool_workflow_not_found_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when app has no workflow.
|
||||
@ -481,10 +477,9 @@ class TestWorkflowToolManageService:
|
||||
)
|
||||
|
||||
# Remove workflow reference from app
|
||||
from extensions.ext_database import db
|
||||
|
||||
app.workflow_id = None
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Attempt to create workflow tool for app without workflow
|
||||
tool_parameters = self._create_test_workflow_tool_parameters()
|
||||
@ -505,7 +500,7 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Verify no workflow tool was created
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -515,7 +510,7 @@ class TestWorkflowToolManageService:
|
||||
assert tool_count == 0
|
||||
|
||||
def test_create_workflow_tool_human_input_node_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation fails when workflow contains human input nodes.
|
||||
@ -558,10 +553,8 @@ class TestWorkflowToolManageService:
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -570,7 +563,9 @@ class TestWorkflowToolManageService:
|
||||
|
||||
assert tool_count == 0
|
||||
|
||||
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_workflow_tool_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful workflow tool update with valid parameters.
|
||||
|
||||
@ -603,10 +598,9 @@ class TestWorkflowToolManageService:
|
||||
)
|
||||
|
||||
# Get the created tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_tool = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
@ -641,7 +635,7 @@ class TestWorkflowToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify database state was updated
|
||||
db.session.refresh(created_tool)
|
||||
db_session_with_containers.refresh(created_tool)
|
||||
assert created_tool is not None
|
||||
assert created_tool.name == updated_tool_name
|
||||
assert created_tool.label == updated_tool_label
|
||||
@ -658,7 +652,7 @@ class TestWorkflowToolManageService:
|
||||
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
|
||||
|
||||
def test_update_workflow_tool_human_input_node_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool update fails when workflow contains human input nodes.
|
||||
@ -689,10 +683,8 @@ class TestWorkflowToolManageService:
|
||||
parameters=initial_tool_parameters,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_tool = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
@ -712,7 +704,7 @@ class TestWorkflowToolManageService:
|
||||
]
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
WorkflowToolManageService.update_workflow_tool(
|
||||
@ -728,10 +720,12 @@ class TestWorkflowToolManageService:
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
|
||||
db.session.refresh(created_tool)
|
||||
db_session_with_containers.refresh(created_tool)
|
||||
assert created_tool.name == original_name
|
||||
|
||||
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_update_workflow_tool_not_found_error(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool update fails when tool does not exist.
|
||||
|
||||
@ -768,10 +762,9 @@ class TestWorkflowToolManageService:
|
||||
assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)
|
||||
|
||||
# Verify no workflow tool was created
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
)
|
||||
@ -781,7 +774,7 @@ class TestWorkflowToolManageService:
|
||||
assert tool_count == 0
|
||||
|
||||
def test_update_workflow_tool_same_name_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool update succeeds when keeping the same name.
|
||||
@ -813,10 +806,9 @@ class TestWorkflowToolManageService:
|
||||
)
|
||||
|
||||
# Get the created tool
|
||||
from extensions.ext_database import db
|
||||
|
||||
created_tool = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.app_id == app.id,
|
||||
@ -840,12 +832,12 @@ class TestWorkflowToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
# Verify tool still exists with the same name
|
||||
db.session.refresh(created_tool)
|
||||
db_session_with_containers.refresh(created_tool)
|
||||
assert created_tool.name == first_tool_name
|
||||
assert created_tool.updated_at is not None
|
||||
|
||||
def test_create_workflow_tool_with_file_parameter_default(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation with FILE parameter having a file object as default.
|
||||
@ -916,7 +908,7 @@ class TestWorkflowToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_workflow_tool_with_files_parameter_default(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
|
||||
@ -991,7 +983,7 @@ class TestWorkflowToolManageService:
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_create_workflow_tool_db_commit_before_validation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that database commit happens before validation, causing DB pollution on validation failure.
|
||||
@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService:
|
||||
|
||||
# Verify the tool was NOT created in database
|
||||
# This is the expected behavior (no pollution)
|
||||
from extensions.ext_database import db
|
||||
|
||||
tool_count = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
db_session_with_containers.query(WorkflowToolProvider)
|
||||
.where(
|
||||
WorkflowToolProvider.tenant_id == account.current_tenant.id,
|
||||
WorkflowToolProvider.name == tool_name,
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
DatasetEntity,
|
||||
@ -79,7 +80,7 @@ class TestWorkflowConverter:
|
||||
mock_config.app_model_config_dict = {}
|
||||
return mock_config
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -100,18 +101,16 @@ class TestWorkflowConverter:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
from models.account import TenantAccountJoin, TenantAccountRole
|
||||
@ -122,15 +121,17 @@ class TestWorkflowConverter:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account):
|
||||
def _create_test_app(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account
|
||||
):
|
||||
"""
|
||||
Helper method to create a test app for testing.
|
||||
|
||||
@ -163,10 +164,8 @@ class TestWorkflowConverter:
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create app model config
|
||||
app_model_config = AppModelConfig(
|
||||
@ -177,16 +176,16 @@ class TestWorkflowConverter:
|
||||
created_by=account.id,
|
||||
updated_by=account.id,
|
||||
)
|
||||
db.session.add(app_model_config)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app_model_config)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Link app model config to app
|
||||
app.app_model_config_id = app_model_config.id
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return app
|
||||
|
||||
def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
Test successful conversion of app to workflow.
|
||||
|
||||
@ -225,19 +224,18 @@ class TestWorkflowConverter:
|
||||
assert new_app.created_by == account.id
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(new_app)
|
||||
db_session_with_containers.refresh(new_app)
|
||||
assert new_app.id is not None
|
||||
|
||||
# Verify workflow was created
|
||||
workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first()
|
||||
workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first()
|
||||
assert workflow is not None
|
||||
assert workflow.tenant_id == app.tenant_id
|
||||
assert workflow.type == "chat"
|
||||
|
||||
def test_convert_to_workflow_without_app_model_config_error(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling when app model config is missing.
|
||||
@ -270,16 +268,14 @@ class TestWorkflowConverter:
|
||||
updated_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(app)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
# Check initial state
|
||||
initial_workflow_count = db.session.query(Workflow).count()
|
||||
initial_workflow_count = db_session_with_containers.query(Workflow).count()
|
||||
|
||||
with pytest.raises(ValueError, match="App model config is required"):
|
||||
workflow_converter.convert_to_workflow(
|
||||
@ -294,12 +290,12 @@ class TestWorkflowConverter:
|
||||
# Verify database state remains unchanged
|
||||
# The workflow creation happens in convert_app_model_config_to_workflow
|
||||
# which is called before the app_model_config check, so we need to clean up
|
||||
db.session.rollback()
|
||||
final_workflow_count = db.session.query(Workflow).count()
|
||||
db_session_with_containers.rollback()
|
||||
final_workflow_count = db_session_with_containers.query(Workflow).count()
|
||||
assert final_workflow_count == initial_workflow_count
|
||||
|
||||
def test_convert_app_model_config_to_workflow_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion of app model config to workflow.
|
||||
@ -356,16 +352,17 @@ class TestWorkflowConverter:
|
||||
assert answer_node["id"] == "answer"
|
||||
|
||||
# Verify database state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(workflow)
|
||||
db_session_with_containers.refresh(workflow)
|
||||
assert workflow.id is not None
|
||||
|
||||
# Verify features were set
|
||||
features = json.loads(workflow._features) if workflow._features else {}
|
||||
assert isinstance(features, dict)
|
||||
|
||||
def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_convert_to_start_node_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion to start node.
|
||||
|
||||
@ -410,7 +407,9 @@ class TestWorkflowConverter:
|
||||
assert second_variable["label"] == "Number Input"
|
||||
assert second_variable["type"] == "number"
|
||||
|
||||
def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_convert_to_http_request_node_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion to HTTP request node.
|
||||
|
||||
@ -436,10 +435,8 @@ class TestWorkflowConverter:
|
||||
api_endpoint="https://api.example.com/test",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(api_based_extension)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(api_based_extension)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock encrypter
|
||||
mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key"
|
||||
@ -489,7 +486,7 @@ class TestWorkflowConverter:
|
||||
assert external_data_variable_node_mapping["external_data"] == code_node["id"]
|
||||
|
||||
def test_convert_to_knowledge_retrieval_node_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful conversion to knowledge retrieval node.
|
||||
|
||||
@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
|
||||
@ -31,7 +31,9 @@ class TestAddDocumentToIndexTask:
|
||||
"index_processor": mock_processor,
|
||||
}
|
||||
|
||||
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_dataset_and_document(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Helper method to create a test dataset and document for testing.
|
||||
|
||||
@ -51,15 +53,15 @@ class TestAddDocumentToIndexTask:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -68,8 +70,8 @@ class TestAddDocumentToIndexTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
@ -81,8 +83,8 @@ class TestAddDocumentToIndexTask:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create document
|
||||
document = Document(
|
||||
@ -99,15 +101,15 @@ class TestAddDocumentToIndexTask:
|
||||
enabled=True,
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure doc_form property works correctly
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, document
|
||||
|
||||
def _create_test_segments(self, db_session_with_containers, document, dataset):
|
||||
def _create_test_segments(self, db_session_with_containers: Session, document, dataset):
|
||||
"""
|
||||
Helper method to create test document segments.
|
||||
|
||||
@ -138,13 +140,15 @@ class TestAddDocumentToIndexTask:
|
||||
status="completed",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
return segments
|
||||
|
||||
def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_add_document_to_index_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful document indexing with paragraph index type.
|
||||
|
||||
@ -180,9 +184,9 @@ class TestAddDocumentToIndexTask:
|
||||
mock_external_service_dependencies["index_processor"].load.assert_called_once()
|
||||
|
||||
# Verify database state changes
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is True
|
||||
assert segment.disabled_at is None
|
||||
assert segment.disabled_by is None
|
||||
@ -191,7 +195,7 @@ class TestAddDocumentToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_with_different_index_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test document indexing with different index types.
|
||||
@ -209,10 +213,10 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Update document to use different index type
|
||||
document.doc_form = IndexStructureType.QA_INDEX
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
# Create segments
|
||||
segments = self._create_test_segments(db_session_with_containers, document, dataset)
|
||||
@ -237,9 +241,9 @@ class TestAddDocumentToIndexTask:
|
||||
assert len(documents) == 3
|
||||
|
||||
# Verify database state changes
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is True
|
||||
assert segment.disabled_at is None
|
||||
assert segment.disabled_by is None
|
||||
@ -248,7 +252,7 @@ class TestAddDocumentToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_document_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of non-existent document.
|
||||
@ -275,7 +279,7 @@ class TestAddDocumentToIndexTask:
|
||||
# because indexing_cache_key is not defined in that case
|
||||
|
||||
def test_add_document_to_index_invalid_indexing_status(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of document with invalid indexing status.
|
||||
@ -294,7 +298,7 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Set invalid indexing status
|
||||
document.indexing_status = "processing"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the task
|
||||
add_document_to_index_task(document.id)
|
||||
@ -304,7 +308,7 @@ class TestAddDocumentToIndexTask:
|
||||
mock_external_service_dependencies["index_processor"].load.assert_not_called()
|
||||
|
||||
def test_add_document_to_index_dataset_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling when document's dataset doesn't exist.
|
||||
@ -326,14 +330,14 @@ class TestAddDocumentToIndexTask:
|
||||
redis_client.set(indexing_cache_key, "processing", ex=300)
|
||||
|
||||
# Delete the dataset to simulate dataset not found scenario
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.delete(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the task
|
||||
add_document_to_index_task(document.id)
|
||||
|
||||
# Assert: Verify error handling
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is False
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error is not None
|
||||
@ -348,7 +352,7 @@ class TestAddDocumentToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_with_parent_child_structure(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test document indexing with parent-child structure.
|
||||
@ -367,10 +371,10 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Update document to use parent-child index type
|
||||
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
# Create segments with mock child chunks
|
||||
segments = self._create_test_segments(db_session_with_containers, document, dataset)
|
||||
@ -413,9 +417,9 @@ class TestAddDocumentToIndexTask:
|
||||
assert len(doc.children) == 2 # Each document has 2 children
|
||||
|
||||
# Verify database state changes
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is True
|
||||
assert segment.disabled_at is None
|
||||
assert segment.disabled_by is None
|
||||
@ -424,7 +428,7 @@ class TestAddDocumentToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_with_already_enabled_segments(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test document indexing when segments are already enabled.
|
||||
@ -459,10 +463,10 @@ class TestAddDocumentToIndexTask:
|
||||
status="completed",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set up Redis cache key
|
||||
indexing_cache_key = f"document_{document.id}_indexing"
|
||||
@ -488,7 +492,7 @@ class TestAddDocumentToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_auto_disable_log_deletion(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test that auto disable logs are properly deleted during indexing.
|
||||
@ -515,10 +519,10 @@ class TestAddDocumentToIndexTask:
|
||||
document_id=document.id,
|
||||
)
|
||||
log_entry.id = str(fake.uuid4())
|
||||
db.session.add(log_entry)
|
||||
db_session_with_containers.add(log_entry)
|
||||
auto_disable_logs.append(log_entry)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set up Redis cache key
|
||||
indexing_cache_key = f"document_{document.id}_indexing"
|
||||
@ -526,7 +530,9 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Verify logs exist before processing
|
||||
existing_logs = (
|
||||
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all()
|
||||
db_session_with_containers.query(DatasetAutoDisableLog)
|
||||
.where(DatasetAutoDisableLog.document_id == document.id)
|
||||
.all()
|
||||
)
|
||||
assert len(existing_logs) == 2
|
||||
|
||||
@ -535,7 +541,9 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Assert: Verify auto disable logs were deleted
|
||||
remaining_logs = (
|
||||
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all()
|
||||
db_session_with_containers.query(DatasetAutoDisableLog)
|
||||
.where(DatasetAutoDisableLog.document_id == document.id)
|
||||
.all()
|
||||
)
|
||||
assert len(remaining_logs) == 0
|
||||
|
||||
@ -547,14 +555,14 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Verify segments were enabled
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is True
|
||||
|
||||
# Verify redis cache was cleared
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_general_exception_handling(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test general exception handling during indexing process.
|
||||
@ -584,7 +592,7 @@ class TestAddDocumentToIndexTask:
|
||||
add_document_to_index_task(document.id)
|
||||
|
||||
# Assert: Verify error handling
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is False
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error is not None
|
||||
@ -593,14 +601,14 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Verify segments were not enabled due to error
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is False # Should remain disabled due to error
|
||||
|
||||
# Verify redis cache was still cleared despite error
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_segment_filtering_edge_cases(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test segment filtering with various edge cases.
|
||||
@ -638,7 +646,7 @@ class TestAddDocumentToIndexTask:
|
||||
status="completed",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment1)
|
||||
db_session_with_containers.add(segment1)
|
||||
segments.append(segment1)
|
||||
|
||||
# Segment 2: Should be processed (enabled=True, status="completed")
|
||||
@ -658,7 +666,7 @@ class TestAddDocumentToIndexTask:
|
||||
status="completed",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment2)
|
||||
db_session_with_containers.add(segment2)
|
||||
segments.append(segment2)
|
||||
|
||||
# Segment 3: Should NOT be processed (enabled=False, status="processing")
|
||||
@ -677,7 +685,7 @@ class TestAddDocumentToIndexTask:
|
||||
status="processing", # Not completed
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment3)
|
||||
db_session_with_containers.add(segment3)
|
||||
segments.append(segment3)
|
||||
|
||||
# Segment 4: Should be processed (enabled=False, status="completed")
|
||||
@ -696,10 +704,10 @@ class TestAddDocumentToIndexTask:
|
||||
status="completed",
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment4)
|
||||
db_session_with_containers.add(segment4)
|
||||
segments.append(segment4)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set up Redis cache key
|
||||
indexing_cache_key = f"document_{document.id}_indexing"
|
||||
@ -728,11 +736,11 @@ class TestAddDocumentToIndexTask:
|
||||
assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3
|
||||
|
||||
# Verify database state changes
|
||||
db.session.refresh(document)
|
||||
db.session.refresh(segment1)
|
||||
db.session.refresh(segment2)
|
||||
db.session.refresh(segment3)
|
||||
db.session.refresh(segment4)
|
||||
db_session_with_containers.refresh(document)
|
||||
db_session_with_containers.refresh(segment1)
|
||||
db_session_with_containers.refresh(segment2)
|
||||
db_session_with_containers.refresh(segment3)
|
||||
db_session_with_containers.refresh(segment4)
|
||||
|
||||
# All segments should be enabled because the task updates ALL segments for the document
|
||||
assert segment1.enabled is True
|
||||
@ -744,7 +752,7 @@ class TestAddDocumentToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_add_document_to_index_comprehensive_error_scenarios(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test comprehensive error scenarios and recovery.
|
||||
@ -779,7 +787,7 @@ class TestAddDocumentToIndexTask:
|
||||
document.indexing_status = "completed"
|
||||
document.error = None
|
||||
document.disabled_at = None
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set up Redis cache key
|
||||
indexing_cache_key = f"document_{document.id}_indexing"
|
||||
@ -789,7 +797,7 @@ class TestAddDocumentToIndexTask:
|
||||
add_document_to_index_task(document.id)
|
||||
|
||||
# Assert: Verify consistent error handling
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.enabled is False, f"Document should be disabled for {error_name}"
|
||||
assert document.indexing_status == "error", f"Document status should be error for {error_name}"
|
||||
assert document.error is not None, f"Error should be recorded for {error_name}"
|
||||
@ -798,7 +806,7 @@ class TestAddDocumentToIndexTask:
|
||||
|
||||
# Verify segments remain disabled due to error
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is False, f"Segments should remain disabled for {error_name}"
|
||||
|
||||
# Verify redis cache was still cleared despite error
|
||||
|
||||
@ -11,8 +11,8 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -49,7 +49,7 @@ class TestBatchCleanDocumentTask:
|
||||
"get_image_ids": mock_get_image_ids,
|
||||
}
|
||||
|
||||
def _create_test_account(self, db_session_with_containers):
|
||||
def _create_test_account(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Helper method to create a test account for testing.
|
||||
|
||||
@ -69,16 +69,16 @@ class TestBatchCleanDocumentTask:
|
||||
status="active",
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -87,15 +87,15 @@ class TestBatchCleanDocumentTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, account):
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, account):
|
||||
"""
|
||||
Helper method to create a test dataset for testing.
|
||||
|
||||
@ -119,12 +119,12 @@ class TestBatchCleanDocumentTask:
|
||||
embedding_model_provider="openai",
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, dataset, account):
|
||||
def _create_test_document(self, db_session_with_containers: Session, dataset, account):
|
||||
"""
|
||||
Helper method to create a test document for testing.
|
||||
|
||||
@ -153,12 +153,12 @@ class TestBatchCleanDocumentTask:
|
||||
doc_form="text_model",
|
||||
)
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
def _create_test_document_segment(self, db_session_with_containers, document, account):
|
||||
def _create_test_document_segment(self, db_session_with_containers: Session, document, account):
|
||||
"""
|
||||
Helper method to create a test document segment for testing.
|
||||
|
||||
@ -186,12 +186,12 @@ class TestBatchCleanDocumentTask:
|
||||
status="completed",
|
||||
)
|
||||
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return segment
|
||||
|
||||
def _create_test_upload_file(self, db_session_with_containers, account):
|
||||
def _create_test_upload_file(self, db_session_with_containers: Session, account):
|
||||
"""
|
||||
Helper method to create a test upload file for testing.
|
||||
|
||||
@ -220,13 +220,13 @@ class TestBatchCleanDocumentTask:
|
||||
used=False,
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
def test_batch_clean_document_task_successful_cleanup(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful cleanup of documents with segments and files.
|
||||
@ -245,7 +245,7 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Update document to reference the upload file
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Store original IDs for verification
|
||||
document_id = document.id
|
||||
@ -261,18 +261,18 @@ class TestBatchCleanDocumentTask:
|
||||
# The task should have processed the segment and cleaned up the database
|
||||
|
||||
# Verify database cleanup
|
||||
db.session.commit() # Ensure all changes are committed
|
||||
db_session_with_containers.commit() # Ensure all changes are committed
|
||||
|
||||
# Check that segment is deleted
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_with_image_files(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup of documents containing image references.
|
||||
@ -300,8 +300,8 @@ class TestBatchCleanDocumentTask:
|
||||
status="completed",
|
||||
)
|
||||
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Store original IDs for verification
|
||||
segment_id = segment.id
|
||||
@ -313,17 +313,17 @@ class TestBatchCleanDocumentTask:
|
||||
)
|
||||
|
||||
# Verify database cleanup
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that segment is deleted
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
# Verify that the task completed successfully by checking the log output
|
||||
# The task should have processed the segment and cleaned up the database
|
||||
|
||||
def test_batch_clean_document_task_no_segments(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup when document has no segments.
|
||||
@ -339,7 +339,7 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Update document to reference the upload file
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Store original IDs for verification
|
||||
document_id = document.id
|
||||
@ -354,21 +354,21 @@ class TestBatchCleanDocumentTask:
|
||||
# Since there are no segments, the task should handle this gracefully
|
||||
|
||||
# Verify database cleanup
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
# Verify database cleanup
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_dataset_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup when dataset is not found.
|
||||
@ -386,8 +386,8 @@ class TestBatchCleanDocumentTask:
|
||||
dataset_id = dataset.id
|
||||
|
||||
# Delete the dataset to simulate not found scenario
|
||||
db.session.delete(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.delete(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Execute the task with non-existent dataset
|
||||
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
|
||||
@ -399,14 +399,14 @@ class TestBatchCleanDocumentTask:
|
||||
mock_external_service_dependencies["storage"].delete.assert_not_called()
|
||||
|
||||
# Verify that no database cleanup occurred
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Document should still exist since cleanup failed
|
||||
existing_document = db.session.query(Document).filter_by(id=document_id).first()
|
||||
existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
|
||||
assert existing_document is not None
|
||||
|
||||
def test_batch_clean_document_task_storage_cleanup_failure(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup when storage operations fail.
|
||||
@ -423,7 +423,7 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Update document to reference the upload file
|
||||
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Store original IDs for verification
|
||||
document_id = document.id
|
||||
@ -442,18 +442,18 @@ class TestBatchCleanDocumentTask:
|
||||
# The task should continue processing even when storage operations fail
|
||||
|
||||
# Verify database cleanup still occurred despite storage failure
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that segment is deleted from database
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that upload file is deleted from database
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_multiple_documents(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup of multiple documents in a single batch operation.
|
||||
@ -482,7 +482,7 @@ class TestBatchCleanDocumentTask:
|
||||
segments.append(segment)
|
||||
upload_files.append(upload_file)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Store original IDs for verification
|
||||
document_ids = [doc.id for doc in documents]
|
||||
@ -498,20 +498,20 @@ class TestBatchCleanDocumentTask:
|
||||
# The task should process all documents and clean up all associated resources
|
||||
|
||||
# Verify database cleanup for all resources
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that all segments are deleted
|
||||
for segment_id in segment_ids:
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that all upload files are deleted
|
||||
for file_id in file_ids:
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_different_doc_forms(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup with different document form types.
|
||||
@ -527,12 +527,12 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
for doc_form in doc_forms:
|
||||
dataset = self._create_test_dataset(db_session_with_containers, account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
document = self._create_test_document(db_session_with_containers, dataset, account)
|
||||
# Update document doc_form
|
||||
document.doc_form = doc_form
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_document_segment(db_session_with_containers, document, account)
|
||||
|
||||
@ -549,20 +549,20 @@ class TestBatchCleanDocumentTask:
|
||||
# The task should handle different document forms correctly
|
||||
|
||||
# Verify database cleanup
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that segment is deleted
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
except Exception as e:
|
||||
# If the task fails due to external service issues (e.g., plugin daemon),
|
||||
# we should still verify that the database state is consistent
|
||||
# This is a common scenario in test environments where external services may not be available
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check if the segment still exists (task may have failed before deletion)
|
||||
existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
if existing_segment is not None:
|
||||
# If segment still exists, the task failed before deletion
|
||||
# This is acceptable in test environments with external service issues
|
||||
@ -572,7 +572,7 @@ class TestBatchCleanDocumentTask:
|
||||
pass
|
||||
|
||||
def test_batch_clean_document_task_large_batch_performance(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test cleanup performance with a large batch of documents.
|
||||
@ -604,7 +604,7 @@ class TestBatchCleanDocumentTask:
|
||||
segments.append(segment)
|
||||
upload_files.append(upload_file)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Store original IDs for verification
|
||||
document_ids = [doc.id for doc in documents]
|
||||
@ -629,20 +629,20 @@ class TestBatchCleanDocumentTask:
|
||||
# The task should handle large batches efficiently
|
||||
|
||||
# Verify database cleanup for all resources
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that all segments are deleted
|
||||
for segment_id in segment_ids:
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that all upload files are deleted
|
||||
for file_id in file_ids:
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
def test_batch_clean_document_task_integration_with_real_database(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test full integration with real database operations.
|
||||
@ -683,12 +683,12 @@ class TestBatchCleanDocumentTask:
|
||||
|
||||
# Add all to database
|
||||
for segment in segments:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Verify initial state
|
||||
assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
|
||||
assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None
|
||||
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
|
||||
assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
|
||||
|
||||
# Store original IDs for verification
|
||||
document_id = document.id
|
||||
@ -704,17 +704,17 @@ class TestBatchCleanDocumentTask:
|
||||
# The task should process all segments and clean up all associated resources
|
||||
|
||||
# Verify database cleanup
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Check that all segments are deleted
|
||||
for segment_id in segment_ids:
|
||||
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
|
||||
assert deleted_segment is None
|
||||
|
||||
# Check that upload file is deleted
|
||||
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
|
||||
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
|
||||
assert deleted_file is None
|
||||
|
||||
# Verify final database state
|
||||
assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
|
||||
assert db.session.query(UploadFile).filter_by(id=file_id).first() is None
|
||||
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
|
||||
assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None
|
||||
|
||||
@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -29,20 +30,19 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
"""Integration tests for batch_create_segment_to_index_task using testcontainers."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
"""Clean up database before each test to ensure isolation."""
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
# Clear all test data
|
||||
db.session.query(DocumentSegment).delete()
|
||||
db.session.query(Document).delete()
|
||||
db.session.query(Dataset).delete()
|
||||
db.session.query(UploadFile).delete()
|
||||
db.session.query(TenantAccountJoin).delete()
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.query(Account).delete()
|
||||
db.session.commit()
|
||||
db_session_with_containers.query(DocumentSegment).delete()
|
||||
db_session_with_containers.query(Document).delete()
|
||||
db_session_with_containers.query(Dataset).delete()
|
||||
db_session_with_containers.query(UploadFile).delete()
|
||||
db_session_with_containers.query(TenantAccountJoin).delete()
|
||||
db_session_with_containers.query(Tenant).delete()
|
||||
db_session_with_containers.query(Account).delete()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Clear Redis cache
|
||||
redis_client.flushdb()
|
||||
@ -75,7 +75,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
"embedding_model": mock_embedding_model,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -95,18 +95,16 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
status="active",
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant for the account
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -115,15 +113,15 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, account, tenant):
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, account, tenant):
|
||||
"""
|
||||
Helper method to create a test dataset for testing.
|
||||
|
||||
@ -148,14 +146,12 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, account, tenant, dataset):
|
||||
def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset):
|
||||
"""
|
||||
Helper method to create a test document for testing.
|
||||
|
||||
@ -186,14 +182,12 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
word_count=0,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
def _create_test_upload_file(self, db_session_with_containers, account, tenant):
|
||||
def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant):
|
||||
"""
|
||||
Helper method to create a test upload file for testing.
|
||||
|
||||
@ -221,10 +215,8 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
used=False,
|
||||
)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(upload_file)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@ -252,7 +244,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
return csv_content
|
||||
|
||||
def test_batch_create_segment_to_index_task_success_text_model(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful batch creation of segments for text model documents.
|
||||
@ -293,11 +285,10 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
)
|
||||
|
||||
# Verify results
|
||||
from extensions.ext_database import db
|
||||
|
||||
# Check that segments were created
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter_by(document_id=document.id)
|
||||
.order_by(DocumentSegment.position)
|
||||
.all()
|
||||
@ -316,7 +307,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert segment.answer is None # text_model doesn't have answers
|
||||
|
||||
# Check that document word count was updated
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.word_count > 0
|
||||
|
||||
# Verify vector service was called
|
||||
@ -331,7 +322,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert cache_value == b"completed"
|
||||
|
||||
def test_batch_create_segment_to_index_task_dataset_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test task failure when dataset does not exist.
|
||||
@ -370,17 +361,16 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created (since dataset doesn't exist)
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
# Verify no documents were modified
|
||||
documents = db.session.query(Document).all()
|
||||
documents = db_session_with_containers.query(Document).all()
|
||||
assert len(documents) == 0
|
||||
|
||||
def test_batch_create_segment_to_index_task_document_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test task failure when document does not exist.
|
||||
@ -419,18 +409,17 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
# Verify dataset remains unchanged (no segments were added to the dataset)
|
||||
db.session.refresh(dataset)
|
||||
segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
db_session_with_containers.refresh(dataset)
|
||||
segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
|
||||
assert len(segments_for_dataset) == 0
|
||||
|
||||
def test_batch_create_segment_to_index_task_document_not_available(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test task failure when document is not available for indexing.
|
||||
@ -498,11 +487,9 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
),
|
||||
]
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for document in test_cases:
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Test each unavailable document
|
||||
for document in test_cases:
|
||||
@ -524,11 +511,11 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created
|
||||
segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
def test_batch_create_segment_to_index_task_upload_file_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test task failure when upload file does not exist.
|
||||
@ -567,17 +554,16 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert cache_value == b"error"
|
||||
|
||||
# Verify no segments were created
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
# Verify document remains unchanged
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.word_count == 0
|
||||
|
||||
def test_batch_create_segment_to_index_task_empty_csv_file(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test task failure when CSV file is empty.
|
||||
@ -619,17 +605,16 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
|
||||
# Verify error handling
|
||||
# Since exception was raised, no segments should be created
|
||||
from extensions.ext_database import db
|
||||
|
||||
segments = db.session.query(DocumentSegment).all()
|
||||
segments = db_session_with_containers.query(DocumentSegment).all()
|
||||
assert len(segments) == 0
|
||||
|
||||
# Verify document remains unchanged
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.word_count == 0
|
||||
|
||||
def test_batch_create_segment_to_index_task_position_calculation(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test proper position calculation for segments when existing segments exist.
|
||||
@ -664,11 +649,9 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
)
|
||||
existing_segments.append(segment)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
for segment in existing_segments:
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create CSV content
|
||||
csv_content = self._create_test_csv_content("text_model")
|
||||
@ -695,7 +678,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
# Verify results
|
||||
# Check that new segments were created with correct positions
|
||||
all_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
db_session_with_containers.query(DocumentSegment)
|
||||
.filter_by(document_id=document.id)
|
||||
.order_by(DocumentSegment.position)
|
||||
.all()
|
||||
@ -716,7 +699,7 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
assert segment.completed_at is not None
|
||||
|
||||
# Check that document word count was updated
|
||||
db.session.refresh(document)
|
||||
db_session_with_containers.refresh(document)
|
||||
assert document.word_count > 0
|
||||
|
||||
# Verify vector service was called
|
||||
|
||||
@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import (
|
||||
@ -37,7 +38,7 @@ class TestCleanDatasetTask:
|
||||
"""Integration tests for clean_dataset_task using testcontainers."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_database(self, db_session_with_containers):
|
||||
def cleanup_database(self, db_session_with_containers: Session):
|
||||
"""Clean up database before each test to ensure isolation."""
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
@ -82,7 +83,7 @@ class TestCleanDatasetTask:
|
||||
"index_processor": mock_index_processor,
|
||||
}
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers):
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -127,7 +128,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, account, tenant):
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, account, tenant):
|
||||
"""
|
||||
Helper method to create a test dataset for testing.
|
||||
|
||||
@ -157,7 +158,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, account, tenant, dataset):
|
||||
def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset):
|
||||
"""
|
||||
Helper method to create a test document for testing.
|
||||
|
||||
@ -194,7 +195,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
return document
|
||||
|
||||
def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document):
|
||||
def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document):
|
||||
"""
|
||||
Helper method to create a test document segment for testing.
|
||||
|
||||
@ -230,7 +231,7 @@ class TestCleanDatasetTask:
|
||||
|
||||
return segment
|
||||
|
||||
def _create_test_upload_file(self, db_session_with_containers, account, tenant):
|
||||
def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant):
|
||||
"""
|
||||
Helper method to create a test upload file for testing.
|
||||
|
||||
@ -264,7 +265,7 @@ class TestCleanDatasetTask:
|
||||
return upload_file
|
||||
|
||||
def test_clean_dataset_task_success_basic_cleanup(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful basic dataset cleanup with minimal data.
|
||||
@ -325,7 +326,7 @@ class TestCleanDatasetTask:
|
||||
mock_storage.delete.assert_not_called()
|
||||
|
||||
def test_clean_dataset_task_success_with_documents_and_segments(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful dataset cleanup with documents and segments.
|
||||
@ -433,7 +434,7 @@ class TestCleanDatasetTask:
|
||||
assert mock_storage.delete.call_count == 3
|
||||
|
||||
def test_clean_dataset_task_success_with_invalid_doc_form(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful dataset cleanup with invalid doc_form handling.
|
||||
@ -493,7 +494,7 @@ class TestCleanDatasetTask:
|
||||
assert mock_factory.call_count == 4
|
||||
|
||||
def test_clean_dataset_task_error_handling_and_rollback(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test error handling and rollback mechanism when database operations fail.
|
||||
@ -542,7 +543,7 @@ class TestCleanDatasetTask:
|
||||
# This demonstrates the resilience of the cleanup process
|
||||
|
||||
def test_clean_dataset_task_with_image_file_references(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test dataset cleanup with image file references in document segments.
|
||||
@ -634,7 +635,7 @@ class TestCleanDatasetTask:
|
||||
mock_get_image_ids.assert_called_once()
|
||||
|
||||
def test_clean_dataset_task_performance_with_large_dataset(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test dataset cleanup performance with large amounts of data.
|
||||
@ -704,11 +705,9 @@ class TestCleanDatasetTask:
|
||||
binding.created_at = datetime.now()
|
||||
bindings.append(binding)
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add_all(metadata_items)
|
||||
db.session.add_all(bindings)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add_all(metadata_items)
|
||||
db_session_with_containers.add_all(bindings)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Measure cleanup performance
|
||||
import time
|
||||
@ -772,7 +771,7 @@ class TestCleanDatasetTask:
|
||||
print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds")
|
||||
|
||||
def test_clean_dataset_task_storage_exception_handling(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test dataset cleanup when storage operations fail.
|
||||
@ -838,7 +837,7 @@ class TestCleanDatasetTask:
|
||||
# consistency in the database
|
||||
|
||||
def test_clean_dataset_task_edge_cases_and_boundary_conditions(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test dataset cleanup with edge cases and boundary conditions.
|
||||
|
||||
@ -13,8 +13,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -34,7 +34,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
mock_processor.clean.return_value = None
|
||||
yield mock_processor
|
||||
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]:
|
||||
def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]:
|
||||
"""
|
||||
Helper method to create a test account and tenant for testing.
|
||||
|
||||
@ -53,8 +53,8 @@ class TestDisableSegmentFromIndexTask:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
@ -62,8 +62,8 @@ class TestDisableSegmentFromIndexTask:
|
||||
status="normal",
|
||||
plan="basic",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join with owner role
|
||||
join = TenantAccountJoin(
|
||||
@ -72,15 +72,15 @@ class TestDisableSegmentFromIndexTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set current tenant for account
|
||||
account.current_tenant = tenant
|
||||
|
||||
return account, tenant
|
||||
|
||||
def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset:
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset:
|
||||
"""
|
||||
Helper method to create a test dataset.
|
||||
|
||||
@ -101,13 +101,18 @@ class TestDisableSegmentFromIndexTask:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(
|
||||
self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model"
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
dataset: Dataset,
|
||||
tenant: Tenant,
|
||||
account: Account,
|
||||
doc_form: str = "text_model",
|
||||
) -> Document:
|
||||
"""
|
||||
Helper method to create a test document.
|
||||
@ -140,13 +145,14 @@ class TestDisableSegmentFromIndexTask:
|
||||
tokens=500,
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return document
|
||||
|
||||
def _create_test_segment(
|
||||
self,
|
||||
db_session_with_containers: Session,
|
||||
document: Document,
|
||||
dataset: Dataset,
|
||||
tenant: Tenant,
|
||||
@ -185,12 +191,12 @@ class TestDisableSegmentFromIndexTask:
|
||||
created_by=account.id,
|
||||
completed_at=datetime.now(UTC) if status == "completed" else None,
|
||||
)
|
||||
db.session.add(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return segment
|
||||
|
||||
def test_disable_segment_success(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test successful segment disabling from index.
|
||||
|
||||
@ -202,9 +208,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Set up Redis cache
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
@ -226,10 +232,10 @@ class TestDisableSegmentFromIndexTask:
|
||||
assert redis_client.get(indexing_cache_key) is None
|
||||
|
||||
# Verify segment is still in database
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.id is not None
|
||||
|
||||
def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when segment is not found.
|
||||
|
||||
@ -251,7 +257,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when segment is not in completed status.
|
||||
|
||||
@ -262,9 +268,11 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data with non-completed segment
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(
|
||||
db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True
|
||||
)
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -275,7 +283,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when segment has no associated dataset.
|
||||
|
||||
@ -286,13 +294,13 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Manually remove dataset association
|
||||
segment.dataset_id = "00000000-0000-0000-0000-000000000000"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -303,7 +311,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when segment has no associated document.
|
||||
|
||||
@ -314,13 +322,13 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Manually remove document association
|
||||
segment.document_id = "00000000-0000-0000-0000-000000000000"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -331,7 +339,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when document is disabled.
|
||||
|
||||
@ -342,12 +350,12 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data with disabled document
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
document.enabled = False
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -358,7 +366,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when document is archived.
|
||||
|
||||
@ -369,12 +377,12 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data with archived document
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
document.archived = True
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -385,7 +393,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_document_indexing_not_completed(
|
||||
self, db_session_with_containers: Session, mock_index_processor
|
||||
):
|
||||
"""
|
||||
Test handling when document indexing is not completed.
|
||||
|
||||
@ -396,12 +406,12 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data with incomplete indexing
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
document.indexing_status = "indexing"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -412,7 +422,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
# Verify index processor was not called
|
||||
mock_index_processor.clean.assert_not_called()
|
||||
|
||||
def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test handling when index processor raises an exception.
|
||||
|
||||
@ -424,9 +434,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Set up Redis cache
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
@ -449,13 +459,13 @@ class TestDisableSegmentFromIndexTask:
|
||||
assert call_args[0][1] == [segment.index_node_id] # Check index node IDs
|
||||
|
||||
# Verify segment was re-enabled
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is True
|
||||
|
||||
# Verify Redis cache was still cleared
|
||||
assert redis_client.get(indexing_cache_key) is None
|
||||
|
||||
def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test disabling segments with different document forms.
|
||||
|
||||
@ -470,9 +480,11 @@ class TestDisableSegmentFromIndexTask:
|
||||
for doc_form in doc_forms:
|
||||
# Arrange: Create test data for each form
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account, doc_form=doc_form)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(
|
||||
db_session_with_containers, dataset, tenant, account, doc_form=doc_form
|
||||
)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Reset mock for each iteration
|
||||
mock_index_processor.reset_mock()
|
||||
@ -489,7 +501,7 @@ class TestDisableSegmentFromIndexTask:
|
||||
assert call_args[0][0].id == dataset.id # Check dataset ID
|
||||
assert call_args[0][1] == [segment.index_node_id] # Check index node IDs
|
||||
|
||||
def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test Redis cache handling during segment disabling.
|
||||
|
||||
@ -500,9 +512,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Test with cache present
|
||||
indexing_cache_key = f"segment_{segment.id}_indexing"
|
||||
@ -517,13 +529,13 @@ class TestDisableSegmentFromIndexTask:
|
||||
assert redis_client.get(indexing_cache_key) is None
|
||||
|
||||
# Test with no cache present
|
||||
segment2 = self._create_test_segment(document, dataset, tenant, account)
|
||||
segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
result2 = disable_segment_from_index_task(segment2.id)
|
||||
|
||||
# Assert: Verify task still works without cache
|
||||
assert result2 is None
|
||||
|
||||
def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test performance timing of segment disabling task.
|
||||
|
||||
@ -534,9 +546,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Act: Execute the task and measure time
|
||||
start_time = time.perf_counter()
|
||||
@ -548,7 +560,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
execution_time = end_time - start_time
|
||||
assert execution_time < 5.0 # Should complete within 5 seconds
|
||||
|
||||
def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_database_session_management(
|
||||
self, db_session_with_containers: Session, mock_index_processor
|
||||
):
|
||||
"""
|
||||
Test database session management during task execution.
|
||||
|
||||
@ -559,9 +573,9 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create test data
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
|
||||
# Act: Execute the task
|
||||
result = disable_segment_from_index_task(segment.id)
|
||||
@ -570,10 +584,10 @@ class TestDisableSegmentFromIndexTask:
|
||||
assert result is None
|
||||
|
||||
# Verify segment is still accessible (session was properly managed)
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.id is not None
|
||||
|
||||
def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor):
|
||||
def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor):
|
||||
"""
|
||||
Test concurrent execution of segment disabling tasks.
|
||||
|
||||
@ -584,12 +598,12 @@ class TestDisableSegmentFromIndexTask:
|
||||
"""
|
||||
# Arrange: Create multiple test segments
|
||||
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
|
||||
dataset = self._create_test_dataset(tenant, account)
|
||||
document = self._create_test_document(dataset, tenant, account)
|
||||
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
|
||||
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
|
||||
|
||||
segments = []
|
||||
for i in range(3):
|
||||
segment = self._create_test_segment(document, dataset, tenant, account)
|
||||
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
|
||||
segments.append(segment)
|
||||
|
||||
# Act: Execute tasks concurrently (simulated)
|
||||
|
||||
@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import Account, Dataset, DocumentSegment
|
||||
from models import Document as DatasetDocument
|
||||
@ -31,7 +32,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
and realistic testing environment with actual database interactions.
|
||||
"""
|
||||
|
||||
def _create_test_account(self, db_session_with_containers, fake=None):
|
||||
def _create_test_account(self, db_session_with_containers: Session, fake=None):
|
||||
"""
|
||||
Helper method to create a test account with realistic data.
|
||||
|
||||
@ -79,7 +80,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
return account
|
||||
|
||||
def _create_test_dataset(self, db_session_with_containers, account, fake=None):
|
||||
def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None):
|
||||
"""
|
||||
Helper method to create a test dataset with realistic data.
|
||||
|
||||
@ -113,7 +114,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
return dataset
|
||||
|
||||
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None):
|
||||
def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None):
|
||||
"""
|
||||
Helper method to create a test document with realistic data.
|
||||
|
||||
@ -158,7 +159,9 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
return document
|
||||
|
||||
def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None):
|
||||
def _create_test_segments(
|
||||
self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None
|
||||
):
|
||||
"""
|
||||
Helper method to create test document segments with realistic data.
|
||||
|
||||
@ -210,7 +213,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
return segments
|
||||
|
||||
def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None):
|
||||
def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None):
|
||||
"""
|
||||
Helper method to create a dataset process rule.
|
||||
|
||||
@ -239,14 +242,12 @@ class TestDisableSegmentsFromIndexTask:
|
||||
process_rule.created_by = dataset.created_by
|
||||
process_rule.updated_by = dataset.updated_by
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(process_rule)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(process_rule)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return process_rule
|
||||
|
||||
def test_disable_segments_success(self, db_session_with_containers):
|
||||
def test_disable_segments_success(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test successful disabling of segments from index.
|
||||
|
||||
@ -297,7 +298,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
expected_key = f"segment_{segment.id}_indexing"
|
||||
mock_redis.delete.assert_any_call(expected_key)
|
||||
|
||||
def test_disable_segments_dataset_not_found(self, db_session_with_containers):
|
||||
def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when dataset is not found.
|
||||
|
||||
@ -320,7 +321,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
# Redis should not be called when dataset is not found
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_document_not_found(self, db_session_with_containers):
|
||||
def test_disable_segments_document_not_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when document is not found.
|
||||
|
||||
@ -344,7 +345,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
# Redis should not be called when document is not found
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_document_invalid_status(self, db_session_with_containers):
|
||||
def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when document has invalid status for disabling.
|
||||
|
||||
@ -360,9 +361,8 @@ class TestDisableSegmentsFromIndexTask:
|
||||
|
||||
# Test case 1: Document not enabled
|
||||
document.enabled = False
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
|
||||
@ -379,7 +379,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
# Test case 2: Document archived
|
||||
document.enabled = True
|
||||
document.archived = True
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
# Act
|
||||
@ -393,7 +393,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
document.enabled = True
|
||||
document.archived = False
|
||||
document.indexing_status = "indexing"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
|
||||
# Act
|
||||
@ -403,7 +403,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
assert result is None # Task should complete without returning a value
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_no_segments_found(self, db_session_with_containers):
|
||||
def test_disable_segments_no_segments_found(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when no segments are found for the given IDs.
|
||||
|
||||
@ -430,7 +430,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
# Redis should not be called when no segments are found
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_index_processor_error(self, db_session_with_containers):
|
||||
def test_disable_segments_index_processor_error(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when index processor encounters an error.
|
||||
|
||||
@ -464,13 +464,14 @@ class TestDisableSegmentsFromIndexTask:
|
||||
assert result is None # Task should complete without returning a value
|
||||
|
||||
# Verify segments were rolled back to enabled state
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.refresh(segments[0])
|
||||
db.session.refresh(segments[1])
|
||||
db_session_with_containers.refresh(segments[0])
|
||||
db_session_with_containers.refresh(segments[1])
|
||||
|
||||
# Check that segments are re-enabled after error
|
||||
updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
|
||||
updated_segments = (
|
||||
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
|
||||
)
|
||||
|
||||
for segment in updated_segments:
|
||||
assert segment.enabled is True
|
||||
@ -480,7 +481,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
# Verify Redis cache cleanup was still called
|
||||
assert mock_redis.delete.call_count == len(segments)
|
||||
|
||||
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers):
|
||||
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test disabling segments with different document forms.
|
||||
|
||||
@ -503,9 +504,8 @@ class TestDisableSegmentsFromIndexTask:
|
||||
for doc_form in doc_forms:
|
||||
# Update document form
|
||||
document.doc_form = doc_form
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the index processor factory
|
||||
with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
|
||||
@ -523,7 +523,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
assert result is None # Task should complete without returning a value
|
||||
mock_factory.assert_called_with(doc_form)
|
||||
|
||||
def test_disable_segments_performance_timing(self, db_session_with_containers):
|
||||
def test_disable_segments_performance_timing(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that the task properly measures and logs performance timing.
|
||||
|
||||
@ -568,7 +568,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
assert performance_log is not None
|
||||
assert "0.5" in performance_log # Should log the execution time
|
||||
|
||||
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers):
|
||||
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that Redis cache is properly cleaned up for all segments.
|
||||
|
||||
@ -610,7 +610,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
for expected_key in expected_keys:
|
||||
assert expected_key in actual_calls
|
||||
|
||||
def test_disable_segments_database_session_cleanup(self, db_session_with_containers):
|
||||
def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test that database session is properly closed after task execution.
|
||||
|
||||
@ -643,7 +643,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
assert result is None # Task should complete without returning a value
|
||||
# Session lifecycle is managed by context manager; no explicit close assertion
|
||||
|
||||
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
|
||||
def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when empty segment IDs list is provided.
|
||||
|
||||
@ -669,7 +669,7 @@ class TestDisableSegmentsFromIndexTask:
|
||||
# Redis should not be called when no segments are provided
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers):
|
||||
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Test handling when some segment IDs are valid and others are invalid.
|
||||
|
||||
|
||||
@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -31,7 +31,9 @@ class TestEnableSegmentsToIndexTask:
|
||||
"index_processor": mock_processor,
|
||||
}
|
||||
|
||||
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def _create_test_dataset_and_document(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Helper method to create a test dataset and document for testing.
|
||||
|
||||
@ -51,15 +53,15 @@ class TestEnableSegmentsToIndexTask:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -68,8 +70,8 @@ class TestEnableSegmentsToIndexTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create dataset
|
||||
dataset = Dataset(
|
||||
@ -81,8 +83,8 @@ class TestEnableSegmentsToIndexTask:
|
||||
indexing_technique="high_quality",
|
||||
created_by=account.id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(dataset)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create document
|
||||
document = Document(
|
||||
@ -99,16 +101,16 @@ class TestEnableSegmentsToIndexTask:
|
||||
enabled=True,
|
||||
doc_form=IndexStructureType.PARAGRAPH_INDEX,
|
||||
)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(document)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure doc_form property works correctly
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
return dataset, document
|
||||
|
||||
def _create_test_segments(
|
||||
self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed"
|
||||
self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed"
|
||||
):
|
||||
"""
|
||||
Helper method to create test document segments.
|
||||
@ -144,14 +146,14 @@ class TestEnableSegmentsToIndexTask:
|
||||
status=status,
|
||||
created_by=document.created_by,
|
||||
)
|
||||
db.session.add(segment)
|
||||
db_session_with_containers.add(segment)
|
||||
segments.append(segment)
|
||||
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
return segments
|
||||
|
||||
def test_enable_segments_to_index_with_different_index_type(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test segments indexing with different index types.
|
||||
@ -169,10 +171,10 @@ class TestEnableSegmentsToIndexTask:
|
||||
|
||||
# Update document to use different index type
|
||||
document.doc_form = IndexStructureType.QA_INDEX
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
# Create segments
|
||||
segments = self._create_test_segments(db_session_with_containers, document, dataset)
|
||||
@ -204,7 +206,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_enable_segments_to_index_dataset_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of non-existent dataset.
|
||||
@ -229,7 +231,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
mock_external_service_dependencies["index_processor"].load.assert_not_called()
|
||||
|
||||
def test_enable_segments_to_index_document_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of non-existent document.
|
||||
@ -256,7 +258,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
mock_external_service_dependencies["index_processor"].load.assert_not_called()
|
||||
|
||||
def test_enable_segments_to_index_invalid_document_status(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling of document with invalid status.
|
||||
@ -284,12 +286,12 @@ class TestEnableSegmentsToIndexTask:
|
||||
document.enabled = True
|
||||
document.archived = False
|
||||
document.indexing_status = "completed"
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Set invalid status
|
||||
for attr, value in status_attrs.items():
|
||||
setattr(document, attr, value)
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create segments
|
||||
segments = self._create_test_segments(db_session_with_containers, document, dataset)
|
||||
@ -304,11 +306,11 @@ class TestEnableSegmentsToIndexTask:
|
||||
|
||||
# Clean up segments for next iteration
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
db_session_with_containers.delete(segment)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
def test_enable_segments_to_index_segments_not_found(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test handling when no segments are found.
|
||||
@ -338,7 +340,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
mock_external_service_dependencies["index_processor"].load.assert_not_called()
|
||||
|
||||
def test_enable_segments_to_index_with_parent_child_structure(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test segments indexing with parent-child structure.
|
||||
@ -357,10 +359,10 @@ class TestEnableSegmentsToIndexTask:
|
||||
|
||||
# Update document to use parent-child index type
|
||||
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
|
||||
db.session.commit()
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh dataset to ensure doc_form property reflects the updated document
|
||||
db.session.refresh(dataset)
|
||||
db_session_with_containers.refresh(dataset)
|
||||
|
||||
# Create segments with mock child chunks
|
||||
segments = self._create_test_segments(db_session_with_containers, document, dataset)
|
||||
@ -410,7 +412,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
assert redis_client.exists(indexing_cache_key) == 0
|
||||
|
||||
def test_enable_segments_to_index_general_exception_handling(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test general exception handling during indexing process.
|
||||
@ -443,7 +445,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
|
||||
# Assert: Verify error handling
|
||||
for segment in segments:
|
||||
db.session.refresh(segment)
|
||||
db_session_with_containers.refresh(segment)
|
||||
assert segment.enabled is False
|
||||
assert segment.status == "error"
|
||||
assert segment.error is not None
|
||||
|
||||
@ -2,8 +2,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.email_i18n import EmailType
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task
|
||||
@ -30,7 +30,7 @@ class TestMailAccountDeletionTask:
|
||||
"email_service": mock_email_service,
|
||||
}
|
||||
|
||||
def _create_test_account(self, db_session_with_containers):
|
||||
def _create_test_account(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Helper method to create a test account for testing.
|
||||
|
||||
@ -49,16 +49,16 @@ class TestMailAccountDeletionTask:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -67,12 +67,14 @@ class TestMailAccountDeletionTask:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
return account
|
||||
|
||||
def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
def test_send_deletion_success_task_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful account deletion success email sending.
|
||||
|
||||
@ -109,7 +111,7 @@ class TestMailAccountDeletionTask:
|
||||
)
|
||||
|
||||
def test_send_deletion_success_task_mail_not_initialized(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test account deletion success email when mail service is not initialized.
|
||||
@ -132,7 +134,7 @@ class TestMailAccountDeletionTask:
|
||||
mock_external_service_dependencies["email_service"].send_email.assert_not_called()
|
||||
|
||||
def test_send_deletion_success_task_email_service_exception(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test account deletion success email when email service raises exception.
|
||||
@ -154,7 +156,7 @@ class TestMailAccountDeletionTask:
|
||||
mock_external_service_dependencies["email_service"].send_email.assert_called_once()
|
||||
|
||||
def test_send_account_deletion_verification_code_success(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test successful account deletion verification code email sending.
|
||||
@ -193,7 +195,7 @@ class TestMailAccountDeletionTask:
|
||||
)
|
||||
|
||||
def test_send_account_deletion_verification_code_mail_not_initialized(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test account deletion verification code email when mail service is not initialized.
|
||||
@ -217,7 +219,7 @@ class TestMailAccountDeletionTask:
|
||||
mock_external_service_dependencies["email_service"].send_email.assert_not_called()
|
||||
|
||||
def test_send_account_deletion_verification_code_email_service_exception(
|
||||
self, db_session_with_containers, mock_external_service_dependencies
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""
|
||||
Test account deletion verification code email when email service raises exception.
|
||||
|
||||
@ -4,11 +4,11 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
|
||||
from extensions.ext_database import db
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Pipeline
|
||||
from models.workflow import Workflow
|
||||
@ -52,7 +52,7 @@ class TestRagPipelineRunTasks:
|
||||
"delete_file": mock_delete_file,
|
||||
}
|
||||
|
||||
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
|
||||
def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session):
|
||||
"""
|
||||
Helper method to create test pipeline and workflow for testing.
|
||||
|
||||
@ -71,15 +71,15 @@ class TestRagPipelineRunTasks:
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(
|
||||
name=fake.company(),
|
||||
status="normal",
|
||||
)
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create tenant-account join
|
||||
join = TenantAccountJoin(
|
||||
@ -88,8 +88,8 @@ class TestRagPipelineRunTasks:
|
||||
role=TenantAccountRole.OWNER,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(join)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create workflow
|
||||
workflow = Workflow(
|
||||
@ -107,8 +107,8 @@ class TestRagPipelineRunTasks:
|
||||
conversation_variables=[],
|
||||
rag_pipeline_variables=[],
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(workflow)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(
|
||||
@ -119,14 +119,14 @@ class TestRagPipelineRunTasks:
|
||||
created_by=account.id,
|
||||
)
|
||||
pipeline.id = str(uuid.uuid4())
|
||||
db.session.add(pipeline)
|
||||
db.session.commit()
|
||||
db_session_with_containers.add(pipeline)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Refresh entities to ensure they're properly loaded
|
||||
db.session.refresh(account)
|
||||
db.session.refresh(tenant)
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(pipeline)
|
||||
db_session_with_containers.refresh(account)
|
||||
db_session_with_containers.refresh(tenant)
|
||||
db_session_with_containers.refresh(workflow)
|
||||
db_session_with_containers.refresh(pipeline)
|
||||
|
||||
return account, tenant, pipeline, workflow
|
||||
|
||||
@ -209,7 +209,7 @@ class TestRagPipelineRunTasks:
|
||||
return json.dumps(entities_data)
|
||||
|
||||
def test_priority_rag_pipeline_run_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test successful priority RAG pipeline run task execution.
|
||||
@ -254,7 +254,7 @@ class TestRagPipelineRunTasks:
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_rag_pipeline_run_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test successful regular RAG pipeline run task execution.
|
||||
@ -299,7 +299,7 @@ class TestRagPipelineRunTasks:
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
|
||||
@ -351,7 +351,7 @@ class TestRagPipelineRunTasks:
|
||||
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
|
||||
|
||||
def test_rag_pipeline_run_task_legacy_compatibility(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
|
||||
@ -419,7 +419,7 @@ class TestRagPipelineRunTasks:
|
||||
redis_client.delete(legacy_task_key)
|
||||
|
||||
def test_rag_pipeline_run_task_with_waiting_tasks(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
|
||||
@ -469,7 +469,7 @@ class TestRagPipelineRunTasks:
|
||||
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
|
||||
|
||||
def test_priority_rag_pipeline_run_task_error_handling(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test error handling in priority RAG pipeline run task using real Redis.
|
||||
@ -526,7 +526,7 @@ class TestRagPipelineRunTasks:
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_rag_pipeline_run_task_error_handling(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test error handling in regular RAG pipeline run task using real Redis.
|
||||
@ -581,7 +581,7 @@ class TestRagPipelineRunTasks:
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_priority_rag_pipeline_run_task_tenant_isolation(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in priority RAG pipeline run task using real Redis.
|
||||
@ -648,7 +648,7 @@ class TestRagPipelineRunTasks:
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
def test_rag_pipeline_run_task_tenant_isolation(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test tenant isolation in regular RAG pipeline run task using real Redis.
|
||||
@ -713,7 +713,7 @@ class TestRagPipelineRunTasks:
|
||||
assert queue1._task_key != queue2._task_key
|
||||
|
||||
def test_run_single_rag_pipeline_task_success(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test successful run_single_rag_pipeline_task execution.
|
||||
@ -748,7 +748,7 @@ class TestRagPipelineRunTasks:
|
||||
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
|
||||
|
||||
def test_run_single_rag_pipeline_task_entity_validation_error(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test run_single_rag_pipeline_task with invalid entity data.
|
||||
@ -793,7 +793,7 @@ class TestRagPipelineRunTasks:
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
def test_run_single_rag_pipeline_task_database_entity_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
|
||||
):
|
||||
"""
|
||||
Test run_single_rag_pipeline_task with non-existent database entities.
|
||||
@ -838,7 +838,7 @@ class TestRagPipelineRunTasks:
|
||||
mock_pipeline_generator.assert_not_called()
|
||||
|
||||
def test_priority_rag_pipeline_run_task_file_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test priority RAG pipeline run task with non-existent file.
|
||||
@ -888,7 +888,7 @@ class TestRagPipelineRunTasks:
|
||||
assert len(remaining_tasks) == 0
|
||||
|
||||
def test_rag_pipeline_run_task_file_not_found(
|
||||
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
|
||||
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
|
||||
):
|
||||
"""
|
||||
Test regular RAG pipeline run task with non-existent file.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user