feat: replace db.session with db_session_with_containers (#32942)

This commit is contained in:
Renzo 2026-03-03 21:50:41 -08:00 committed by GitHub
parent 2f4c740d46
commit ad000c42b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 3078 additions and 2669 deletions

View File

@@ -9,8 +9,8 @@ from itertools import starmap
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from services.dataset_service import DatasetCollectionBindingService
@ -28,6 +28,7 @@ class DatasetCollectionBindingTestDataFactory:
@staticmethod
def create_collection_binding(
db_session_with_containers: Session,
provider_name: str = "openai",
model_name: str = "text-embedding-ada-002",
collection_name: str = "collection-abc",
@ -51,8 +52,8 @@ class DatasetCollectionBindingTestDataFactory:
collection_name=collection_name,
type=collection_type,
)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(binding)
db_session_with_containers.commit()
return binding
@ -64,7 +65,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
including various provider/model combinations, collection types, and edge cases.
"""
def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers):
def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session):
"""
Test successful retrieval of an existing collection binding.
@ -77,6 +78,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
model_name = "text-embedding-ada-002"
collection_type = "dataset"
existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name=provider_name,
model_name=model_name,
collection_name="existing-collection",
@ -92,7 +94,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.id == existing_binding.id
assert result.collection_name == "existing-collection"
def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers):
def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session):
"""
Test successful creation of a new collection binding when none exists.
@ -116,7 +118,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.type == collection_type
assert result.collection_name is not None
def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers):
def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session):
"""Test get_dataset_collection_binding with different collection type."""
# Arrange
provider_name = "openai"
@ -133,7 +135,7 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.provider_name == provider_name
assert result.model_name == model_name
def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers):
def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session):
"""Test get_dataset_collection_binding with default collection type parameter."""
# Arrange
provider_name = "openai"
@ -147,7 +149,9 @@ class TestDatasetCollectionBindingServiceGetBinding:
assert result.provider_name == provider_name
assert result.model_name == model_name
def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers):
def test_get_dataset_collection_binding_different_provider_model_combination(
self, db_session_with_containers: Session
):
"""Test get_dataset_collection_binding with various provider/model combinations."""
# Arrange
combinations = [
@ -174,10 +178,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
including successful retrieval and error handling for missing bindings.
"""
def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session):
"""Test successful retrieval of collection binding by ID and type."""
# Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
@ -194,7 +199,7 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
assert result.collection_name == "test-collection"
assert result.type == "dataset"
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session):
"""Test error handling when collection binding is not found by ID and type."""
# Arrange
non_existent_id = str(uuid4())
@ -203,10 +208,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
with pytest.raises(ValueError, match="Dataset collection binding not found"):
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset")
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(
self, db_session_with_containers: Session
):
"""Test retrieval by ID and type with different collection type."""
# Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
@ -222,10 +230,13 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
assert result.id == binding.id
assert result.type == "custom_type"
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(
self, db_session_with_containers: Session
):
"""Test retrieval by ID with default collection type."""
# Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",
@ -239,10 +250,11 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType:
assert result.id == binding.id
assert result.type == "dataset"
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers):
def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session):
"""Test error when binding exists but with wrong collection type."""
# Arrange
binding = DatasetCollectionBindingTestDataFactory.create_collection_binding(
db_session_with_containers,
provider_name="openai",
model_name="text-embedding-ada-002",
collection_name="test-collection",

View File

@@ -10,9 +10,9 @@ from unittest.mock import patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum
from models.model import App
@ -27,6 +27,7 @@ class DatasetUpdateDeleteTestDataFactory:
@staticmethod
def create_account_with_tenant(
db_session_with_containers: Session,
role: TenantAccountRole = TenantAccountRole.NORMAL,
tenant: Tenant | None = None,
) -> tuple[Account, Tenant]:
@ -37,13 +38,13 @@ class DatasetUpdateDeleteTestDataFactory:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
@ -51,14 +52,15 @@ class DatasetUpdateDeleteTestDataFactory:
role=role,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_dataset(
db_session_with_containers: Session,
tenant_id: str,
created_by: str,
name: str = "Test Dataset",
@ -78,12 +80,12 @@ class DatasetUpdateDeleteTestDataFactory:
retrieval_model={"top_k": 2},
enable_api=enable_api,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App:
def create_app(db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test App") -> App:
"""Create a real app for AppDatasetJoin."""
app = App(
tenant_id=tenant_id,
@ -96,16 +98,16 @@ class DatasetUpdateDeleteTestDataFactory:
enable_api=True,
created_by=created_by,
)
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
@staticmethod
def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin:
def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin:
"""Create a real AppDatasetJoin record."""
join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
return join
@ -114,7 +116,7 @@ class TestDatasetServiceDeleteDataset:
Comprehensive integration tests for DatasetService.delete_dataset method.
"""
def test_delete_dataset_success(self, db_session_with_containers):
def test_delete_dataset_success(self, db_session_with_containers: Session):
"""
Test successful deletion of a dataset.
@ -130,8 +132,10 @@ class TestDatasetServiceDeleteDataset:
- Method returns True
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
# Act
with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted:
@ -139,10 +143,10 @@ class TestDatasetServiceDeleteDataset:
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
assert db_session_with_containers.get(Dataset, dataset.id) is None
mock_dataset_was_deleted.send.assert_called_once_with(dataset)
def test_delete_dataset_not_found(self, db_session_with_containers):
def test_delete_dataset_not_found(self, db_session_with_containers: Session):
"""
Test handling when dataset is not found.
@ -156,7 +160,9 @@ class TestDatasetServiceDeleteDataset:
- No database operations are performed
"""
# Arrange
owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset_id = str(uuid4())
# Act
@ -165,7 +171,7 @@ class TestDatasetServiceDeleteDataset:
# Assert
assert result is False
def test_delete_dataset_permission_denied_error(self, db_session_with_containers):
def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session):
"""
Test error handling when user lacks permission.
@ -178,19 +184,22 @@ class TestDatasetServiceDeleteDataset:
- No database operations are performed
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers,
role=TenantAccountRole.NORMAL,
tenant=tenant,
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
# Act & Assert
with pytest.raises(NoPermissionError):
DatasetService.delete_dataset(dataset.id, normal_user)
# Verify no deletion was attempted
assert db.session.get(Dataset, dataset.id) is not None
assert db_session_with_containers.get(Dataset, dataset.id) is not None
class TestDatasetServiceDatasetUseCheck:
@ -198,7 +207,7 @@ class TestDatasetServiceDatasetUseCheck:
Comprehensive integration tests for DatasetService.dataset_use_check method.
"""
def test_dataset_use_check_in_use(self, db_session_with_containers):
def test_dataset_use_check_in_use(self, db_session_with_containers: Session):
"""
Test detection when dataset is in use.
@ -211,10 +220,12 @@ class TestDatasetServiceDatasetUseCheck:
- Database query is executed
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id)
DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id)
DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id)
# Act
result = DatasetService.dataset_use_check(dataset.id)
@ -222,7 +233,7 @@ class TestDatasetServiceDatasetUseCheck:
# Assert
assert result is True
def test_dataset_use_check_not_in_use(self, db_session_with_containers):
def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session):
"""
Test detection when dataset is not in use.
@ -235,8 +246,10 @@ class TestDatasetServiceDatasetUseCheck:
- Database query is executed
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
# Act
result = DatasetService.dataset_use_check(dataset.id)
@ -250,7 +263,7 @@ class TestDatasetServiceUpdateDatasetApiStatus:
Comprehensive integration tests for DatasetService.update_dataset_api_status method.
"""
def test_update_dataset_api_status_enable_success(self, db_session_with_containers):
def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session):
"""
Test successful enabling of dataset API access.
@ -264,8 +277,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
- Transaction is committed
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
db_session_with_containers, tenant.id, owner.id, enable_api=False
)
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
# Act
@ -276,12 +293,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
DatasetService.update_dataset_api_status(dataset.id, True)
# Assert
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.enable_api is True
assert dataset.updated_by == owner.id
assert dataset.updated_at == current_time
def test_update_dataset_api_status_disable_success(self, db_session_with_containers):
def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session):
"""
Test successful disabling of dataset API access.
@ -295,8 +312,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
- Transaction is committed
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
db_session_with_containers, tenant.id, owner.id, enable_api=True
)
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
# Act
@ -307,11 +328,11 @@ class TestDatasetServiceUpdateDatasetApiStatus:
DatasetService.update_dataset_api_status(dataset.id, False)
# Assert
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.enable_api is False
assert dataset.updated_by == owner.id
def test_update_dataset_api_status_not_found_error(self, db_session_with_containers):
def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session):
"""
Test error handling when dataset is not found.
@ -330,7 +351,7 @@ class TestDatasetServiceUpdateDatasetApiStatus:
with pytest.raises(NotFound, match="Dataset not found"):
DatasetService.update_dataset_api_status(dataset_id, True)
def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers):
def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session):
"""
Test error handling when current_user is missing.
@ -343,8 +364,12 @@ class TestDatasetServiceUpdateDatasetApiStatus:
- No updates are committed
"""
# Arrange
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False)
owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(
db_session_with_containers, tenant.id, owner.id, enable_api=False
)
# Act & Assert
with (
@ -354,6 +379,6 @@ class TestDatasetServiceUpdateDatasetApiStatus:
DatasetService.update_dataset_api_status(dataset.id, True)
# Verify no commit was attempted
db.session.rollback()
db.session.refresh(dataset)
db_session_with_containers.rollback()
db_session_with_containers.refresh(dataset)
assert dataset.enable_api is False

View File

@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
@ -87,7 +88,7 @@ class TestAgentService:
"account_feature_service": mock_account_feature_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -133,13 +134,12 @@ class TestAgentService:
# Update the app model config to set agent_mode for agent-chat mode
if app.mode == "agent-chat" and app.app_model_config:
app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
return app, account
def _create_test_conversation_and_message(self, db_session_with_containers, app, account):
def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account):
"""
Helper method to create a test conversation and message with agent thoughts.
@ -153,8 +153,6 @@ class TestAgentService:
"""
fake = Faker()
from extensions.ext_database import db
# Create conversation
conversation = Conversation(
id=fake.uuid4(),
@ -167,8 +165,8 @@ class TestAgentService:
mode="chat",
from_source="api",
)
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
# Create app model config
app_model_config = AppModelConfig(
@ -180,12 +178,12 @@ class TestAgentService:
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
)
app_model_config.id = fake.uuid4()
db.session.add(app_model_config)
db.session.commit()
db_session_with_containers.add(app_model_config)
db_session_with_containers.commit()
# Update conversation with app model config
conversation.app_model_config_id = app_model_config.id
db.session.commit()
db_session_with_containers.commit()
# Create message
message = Message(
@ -206,12 +204,12 @@ class TestAgentService:
currency="USD",
from_source="api",
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
return conversation, message
def _create_test_agent_thoughts(self, db_session_with_containers, message):
def _create_test_agent_thoughts(self, db_session_with_containers: Session, message):
"""
Helper method to create test agent thoughts for a message.
@ -224,8 +222,6 @@ class TestAgentService:
"""
fake = Faker()
from extensions.ext_database import db
agent_thoughts = []
# Create first agent thought
@ -251,7 +247,7 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(thought1)
db_session_with_containers.add(thought1)
agent_thoughts.append(thought1)
# Create second agent thought
@ -277,14 +273,14 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(thought2)
db_session_with_containers.add(thought2)
agent_thoughts.append(thought2)
db.session.commit()
db_session_with_containers.commit()
return agent_thoughts
def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of agent logs with complete data.
"""
@ -344,7 +340,7 @@ class TestAgentService:
assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon
def test_get_agent_logs_conversation_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when conversation is not found.
@ -358,7 +354,9 @@ class TestAgentService:
with pytest.raises(ValueError, match="Conversation not found"):
AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4())
def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_message_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when message is not found.
"""
@ -372,7 +370,9 @@ class TestAgentService:
with pytest.raises(ValueError, match="Message not found"):
AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4())
def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_with_end_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval when conversation is from end user.
"""
@ -381,8 +381,6 @@ class TestAgentService:
# Create test data
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create end user
end_user = EndUser(
id=fake.uuid4(),
@ -393,8 +391,8 @@ class TestAgentService:
session_id=fake.uuid4(),
name=fake.name(),
)
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
# Create conversation with end user
conversation = Conversation(
@ -408,8 +406,8 @@ class TestAgentService:
mode="chat",
from_source="api",
)
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
# Create app model config
app_model_config = AppModelConfig(
@ -421,12 +419,12 @@ class TestAgentService:
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
)
app_model_config.id = fake.uuid4()
db.session.add(app_model_config)
db.session.commit()
db_session_with_containers.add(app_model_config)
db_session_with_containers.commit()
# Update conversation with app model config
conversation.app_model_config_id = app_model_config.id
db.session.commit()
db_session_with_containers.commit()
# Create message
message = Message(
@ -447,8 +445,8 @@ class TestAgentService:
currency="USD",
from_source="api",
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -457,7 +455,9 @@ class TestAgentService:
assert result is not None
assert result["meta"]["executor"] == end_user.name
def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_with_unknown_executor(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval when executor is unknown.
"""
@ -466,8 +466,6 @@ class TestAgentService:
# Create test data
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create conversation with non-existent account
conversation = Conversation(
id=fake.uuid4(),
@ -480,8 +478,8 @@ class TestAgentService:
mode="chat",
from_source="api",
)
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
# Create app model config
app_model_config = AppModelConfig(
@ -493,12 +491,12 @@ class TestAgentService:
agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}),
)
app_model_config.id = fake.uuid4()
db.session.add(app_model_config)
db.session.commit()
db_session_with_containers.add(app_model_config)
db_session_with_containers.commit()
# Update conversation with app model config
conversation.app_model_config_id = app_model_config.id
db.session.commit()
db_session_with_containers.commit()
# Create message
message = Message(
@ -519,8 +517,8 @@ class TestAgentService:
currency="USD",
from_source="api",
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -529,7 +527,9 @@ class TestAgentService:
assert result is not None
assert result["meta"]["executor"] == "Unknown"
def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_with_tool_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval with tool errors.
"""
@ -539,8 +539,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with tool error
thought_with_error = MessageAgentThought(
message_id=message.id,
@ -564,8 +562,8 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(thought_with_error)
db.session.commit()
db_session_with_containers.add(thought_with_error)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -580,7 +578,7 @@ class TestAgentService:
assert tool_call["error"] == "Tool execution failed"
def test_get_agent_logs_without_agent_thoughts(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval when message has no agent thoughts.
@ -600,7 +598,7 @@ class TestAgentService:
assert len(result["iterations"]) == 0
def test_get_agent_logs_app_model_config_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when app model config is not found.
@ -610,11 +608,9 @@ class TestAgentService:
# Create test data
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Remove app model config to test error handling
app.app_model_config_id = None
db.session.commit()
db_session_with_containers.commit()
# Create conversation without app model config
conversation = Conversation(
@ -629,8 +625,8 @@ class TestAgentService:
from_source="api",
app_model_config_id=None, # Explicitly set to None
)
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
# Create message
message = Message(
@ -651,15 +647,15 @@ class TestAgentService:
currency="USD",
from_source="api",
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
# Execute the method under test
with pytest.raises(ValueError, match="App model config not found"):
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
def test_get_agent_logs_agent_config_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when agent config is not found.
@ -677,7 +673,9 @@ class TestAgentService:
with pytest.raises(ValueError, match="Agent config not found"):
AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_list_agent_providers_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful listing of agent providers.
"""
@ -698,7 +696,7 @@ class TestAgentService:
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id))
def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of specific agent provider.
"""
@ -720,7 +718,9 @@ class TestAgentService:
mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value
mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name)
def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_provider_plugin_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when plugin daemon client raises an error.
"""
@ -741,7 +741,7 @@ class TestAgentService:
AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name)
def test_get_agent_logs_with_complex_tool_data(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval with complex tool data and multiple tools.
@ -752,8 +752,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with multiple tools
complex_thought = MessageAgentThought(
message_id=message.id,
@ -799,8 +797,8 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(complex_thought)
db.session.commit()
db_session_with_containers.add(complex_thought)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -831,7 +829,7 @@ class TestAgentService:
assert tool_calls[2]["status"] == "success"
assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon
def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test agent logs retrieval with message files and agent thought files.
"""
@ -842,7 +840,6 @@ class TestAgentService:
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from dify_graph.file import FileTransferMethod, FileType
from extensions.ext_database import db
from models.enums import CreatorUserRole
# Add files to message
@ -867,9 +864,9 @@ class TestAgentService:
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
db.session.add(message_file1)
db.session.add(message_file2)
db.session.commit()
db_session_with_containers.add(message_file1)
db_session_with_containers.add(message_file2)
db_session_with_containers.commit()
# Create agent thought with files
thought_with_files = MessageAgentThought(
@ -895,8 +892,8 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(thought_with_files)
db.session.commit()
db_session_with_containers.add(thought_with_files)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -912,7 +909,7 @@ class TestAgentService:
assert "file2" in iterations[0]["files"]
def test_get_agent_logs_with_different_timezone(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval with different timezone settings.
@ -938,7 +935,9 @@ class TestAgentService:
assert "T" in start_time # ISO format
assert "+08:00" in start_time or "Z" in start_time # Timezone offset
def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_with_empty_tool_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval with empty tool data.
"""
@ -948,8 +947,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with empty tool data
empty_thought = MessageAgentThought(
message_id=message.id,
@ -964,8 +961,8 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(empty_thought)
db.session.commit()
db_session_with_containers.add(empty_thought)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))
@ -979,7 +976,9 @@ class TestAgentService:
tool_calls = iterations[0]["tool_calls"]
assert len(tool_calls) == 0 # No tools to process
def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_agent_logs_with_malformed_json(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test agent logs retrieval with malformed JSON data in tool fields.
"""
@ -989,8 +988,6 @@ class TestAgentService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account)
from extensions.ext_database import db
# Create agent thought with malformed JSON
malformed_thought = MessageAgentThought(
message_id=message.id,
@ -1005,8 +1002,8 @@ class TestAgentService:
created_by_role="account",
created_by=message.from_account_id,
)
db.session.add(malformed_thought)
db.session.commit()
db_session_with_containers.add(malformed_thought)
db_session_with_containers.commit()
# Execute the method under test
result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id))

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from models import Account
@ -52,7 +53,7 @@ class TestAnnotationService:
"current_user": mock_user,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -115,11 +116,10 @@ class TestAnnotationService:
tenant_id,
)
def _create_test_conversation(self, app, account, fake):
def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake):
"""
Helper method to create a test conversation with all required fields.
"""
from extensions.ext_database import db
from models.model import Conversation
conversation = Conversation(
@ -141,17 +141,16 @@ class TestAnnotationService:
from_account_id=account.id,
)
db.session.add(conversation)
db.session.flush()
db_session_with_containers.add(conversation)
db_session_with_containers.flush()
return conversation
def _create_test_message(self, app, conversation, account, fake):
def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake):
"""
Helper method to create a test message with all required fields.
"""
import json
from extensions.ext_database import db
from models.model import Message
message = Message(
@ -180,12 +179,12 @@ class TestAnnotationService:
from_account_id=account.id,
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
return message
def test_insert_app_annotation_directly_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful direct insertion of app annotation.
@ -211,9 +210,8 @@ class TestAnnotationService:
assert annotation.id is not None
# Verify annotation was saved to database
from extensions.ext_database import db
db.session.refresh(annotation)
db_session_with_containers.refresh(annotation)
assert annotation.id is not None
# Verify add_annotation_to_index_task was called (when annotation setting exists)
@ -221,7 +219,7 @@ class TestAnnotationService:
mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_insert_app_annotation_directly_requires_question(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Question must be provided when inserting annotations directly.
@ -238,7 +236,7 @@ class TestAnnotationService:
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
def test_insert_app_annotation_directly_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test direct insertion of app annotation when app is not found.
@ -260,7 +258,7 @@ class TestAnnotationService:
AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id)
def test_update_app_annotation_directly_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful direct update of app annotation.
@ -298,7 +296,7 @@ class TestAnnotationService:
mock_external_service_dependencies["update_task"].delay.assert_not_called()
def test_up_insert_app_annotation_from_message_new(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creating new annotation from message.
@ -307,8 +305,8 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message first
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Setup annotation data with message_id
annotation_args = {
@ -333,7 +331,7 @@ class TestAnnotationService:
mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_up_insert_app_annotation_from_message_update(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test updating existing annotation from message.
@ -342,8 +340,8 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message first
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial annotation
initial_args = {
@ -373,7 +371,7 @@ class TestAnnotationService:
mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_up_insert_app_annotation_from_message_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creating annotation from message when app is not found.
@ -395,7 +393,7 @@ class TestAnnotationService:
AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id)
def test_get_annotation_list_by_app_id_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of annotation list by app ID.
@ -428,7 +426,7 @@ class TestAnnotationService:
assert annotation.account_id == account.id
def test_get_annotation_list_by_app_id_with_keyword(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of annotation list with keyword search.
@ -462,7 +460,7 @@ class TestAnnotationService:
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content
def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
r"""
Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention.
@ -534,7 +532,7 @@ class TestAnnotationService:
assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list)
def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of annotation list when app is not found.
@ -549,7 +547,9 @@ class TestAnnotationService:
with pytest.raises(NotFound, match="App not found"):
AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="")
def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful deletion of app annotation.
"""
@ -568,16 +568,19 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
# Verify annotation was deleted
from extensions.ext_database import db
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
deleted_annotation = (
db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
)
assert deleted_annotation is None
# Verify delete_annotation_index_task was called (when annotation setting exists)
# Note: In this test, no annotation setting exists, so task should not be called
mock_external_service_dependencies["delete_task"].delay.assert_not_called()
def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_app_annotation_app_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test deletion of app annotation when app is not found.
"""
@ -593,7 +596,7 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id)
def test_delete_app_annotation_annotation_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test deletion of app annotation when annotation is not found.
@ -606,7 +609,9 @@ class TestAnnotationService:
with pytest.raises(NotFound, match="Annotation not found"):
AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id)
def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_enable_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful enabling of app annotation.
"""
@ -632,7 +637,9 @@ class TestAnnotationService:
# Verify task was called
mock_external_service_dependencies["enable_task"].delay.assert_called_once()
def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_disable_app_annotation_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful disabling of app annotation.
"""
@ -651,7 +658,9 @@ class TestAnnotationService:
# Verify task was called
mock_external_service_dependencies["disable_task"].delay.assert_called_once()
def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies):
def test_enable_app_annotation_cached_job(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test enabling app annotation when job is already cached.
"""
@ -685,7 +694,9 @@ class TestAnnotationService:
# Clean up
redis_client.delete(enable_app_annotation_key)
def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_annotation_hit_histories_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of annotation hit histories.
"""
@ -728,7 +739,9 @@ class TestAnnotationService:
assert history.app_id == app.id
assert history.account_id == account.id
def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_add_annotation_history_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful addition of annotation history.
"""
@ -763,16 +776,15 @@ class TestAnnotationService:
)
# Verify hit count was incremented
from extensions.ext_database import db
db.session.refresh(annotation)
db_session_with_containers.refresh(annotation)
assert annotation.hit_count == initial_hit_count + 1
# Verify history was created
from models.model import AppAnnotationHitHistory
history = (
db.session.query(AppAnnotationHitHistory)
db_session_with_containers.query(AppAnnotationHitHistory)
.where(
AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id
)
@ -786,7 +798,9 @@ class TestAnnotationService:
assert history.score == score
assert history.source == "console"
def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_annotation_by_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of annotation by ID.
"""
@ -811,7 +825,9 @@ class TestAnnotationService:
assert retrieved_annotation.content == annotation_args["answer"]
assert retrieved_annotation.account_id == account.id
def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_batch_import_app_annotations_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful batch import of app annotations.
"""
@ -854,7 +870,7 @@ class TestAnnotationService:
mock_external_service_dependencies["batch_import_task"].delay.assert_called_once()
def test_batch_import_app_annotations_empty_file(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test batch import with empty CSV file.
@ -889,7 +905,7 @@ class TestAnnotationService:
assert "empty" in result["error_msg"].lower()
def test_batch_import_app_annotations_quota_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test batch import when quota is exceeded.
@ -935,7 +951,7 @@ class TestAnnotationService:
assert "limit" in result["error_msg"].lower()
def test_get_app_annotation_setting_by_app_id_enabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting enabled app annotation setting by app ID.
@ -944,7 +960,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting
@ -956,8 +971,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}",
)
collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding)
db.session.flush()
db_session_with_containers.add(collection_binding)
db_session_with_containers.flush()
# Create annotation setting
annotation_setting = AppAnnotationSetting(
@ -967,8 +982,8 @@ class TestAnnotationService:
created_user_id=account.id,
updated_user_id=account.id,
)
db.session.add(annotation_setting)
db.session.commit()
db_session_with_containers.add(annotation_setting)
db_session_with_containers.commit()
# Get annotation setting
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id)
@ -981,7 +996,7 @@ class TestAnnotationService:
assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002"
def test_get_app_annotation_setting_by_app_id_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting disabled app annotation setting by app ID.
@ -996,7 +1011,7 @@ class TestAnnotationService:
assert result["enabled"] is False
def test_update_app_annotation_setting_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful update of app annotation setting.
@ -1005,7 +1020,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting
@ -1017,8 +1031,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}",
)
collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding)
db.session.flush()
db_session_with_containers.add(collection_binding)
db_session_with_containers.flush()
# Create annotation setting
annotation_setting = AppAnnotationSetting(
@ -1028,8 +1042,8 @@ class TestAnnotationService:
created_user_id=account.id,
updated_user_id=account.id,
)
db.session.add(annotation_setting)
db.session.commit()
db_session_with_containers.add(annotation_setting)
db_session_with_containers.commit()
# Update annotation setting
update_args = {
@ -1046,11 +1060,11 @@ class TestAnnotationService:
assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002"
# Verify database was updated
db.session.refresh(annotation_setting)
db_session_with_containers.refresh(annotation_setting)
assert annotation_setting.score_threshold == 0.9
def test_export_annotation_list_by_app_id_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful export of annotation list by app ID.
@ -1083,7 +1097,7 @@ class TestAnnotationService:
assert annotation.created_at <= exported_annotations[i - 1].created_at
def test_export_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test export of annotation list when app is not found.
@ -1099,7 +1113,7 @@ class TestAnnotationService:
AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id)
def test_insert_app_annotation_directly_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful direct insertion of app annotation with annotation setting enabled.
@ -1108,7 +1122,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting
@ -1120,8 +1133,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}",
)
collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding)
db.session.flush()
db_session_with_containers.add(collection_binding)
db_session_with_containers.flush()
# Create annotation setting
annotation_setting = AppAnnotationSetting(
@ -1131,8 +1144,8 @@ class TestAnnotationService:
created_user_id=account.id,
updated_user_id=account.id,
)
db.session.add(annotation_setting)
db.session.commit()
db_session_with_containers.add(annotation_setting)
db_session_with_containers.commit()
# Setup annotation data
annotation_args = {
@ -1161,7 +1174,7 @@ class TestAnnotationService:
assert call_args[4] == collection_binding.id # collection_binding_id
def test_update_app_annotation_directly_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful direct update of app annotation with annotation setting enabled.
@ -1170,7 +1183,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting
@ -1182,8 +1194,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}",
)
collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding)
db.session.flush()
db_session_with_containers.add(collection_binding)
db_session_with_containers.flush()
# Create annotation setting
annotation_setting = AppAnnotationSetting(
@ -1193,8 +1205,8 @@ class TestAnnotationService:
created_user_id=account.id,
updated_user_id=account.id,
)
db.session.add(annotation_setting)
db.session.commit()
db_session_with_containers.add(annotation_setting)
db_session_with_containers.commit()
# First, create an annotation
original_args = {
@ -1234,7 +1246,7 @@ class TestAnnotationService:
assert call_args[4] == collection_binding.id # collection_binding_id
def test_delete_app_annotation_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful deletion of app annotation with annotation setting enabled.
@ -1243,7 +1255,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting
@ -1255,8 +1266,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}",
)
collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding)
db.session.flush()
db_session_with_containers.add(collection_binding)
db_session_with_containers.flush()
# Create annotation setting
annotation_setting = AppAnnotationSetting(
@ -1267,8 +1278,8 @@ class TestAnnotationService:
updated_user_id=account.id,
)
db.session.add(annotation_setting)
db.session.commit()
db_session_with_containers.add(annotation_setting)
db_session_with_containers.commit()
# Create an annotation first
annotation_args = {
@ -1285,7 +1296,9 @@ class TestAnnotationService:
AppAnnotationService.delete_app_annotation(app.id, annotation_id)
# Verify annotation was deleted
deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
deleted_annotation = (
db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
)
assert deleted_annotation is None
# Verify delete_annotation_index_task was called
@ -1297,7 +1310,7 @@ class TestAnnotationService:
assert call_args[3] == collection_binding.id # collection_binding_id
def test_up_insert_app_annotation_from_message_with_setting_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creating annotation from message with annotation setting enabled.
@ -1306,7 +1319,6 @@ class TestAnnotationService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create annotation setting first
from extensions.ext_database import db
from models.dataset import DatasetCollectionBinding
from models.model import AppAnnotationSetting
@ -1318,8 +1330,8 @@ class TestAnnotationService:
collection_name=f"annotation_collection_{fake.uuid4()}",
)
collection_binding.id = str(fake.uuid4())
db.session.add(collection_binding)
db.session.flush()
db_session_with_containers.add(collection_binding)
db_session_with_containers.flush()
# Create annotation setting
annotation_setting = AppAnnotationSetting(
@ -1329,12 +1341,12 @@ class TestAnnotationService:
created_user_id=account.id,
updated_user_id=account.id,
)
db.session.add(annotation_setting)
db.session.commit()
db_session_with_containers.add(annotation_setting)
db_session_with_containers.commit()
# Create a conversation and message first
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Setup annotation data with message_id
annotation_args = {

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.api_based_extension import APIBasedExtension
from services.account_service import AccountService, TenantService
@ -31,7 +32,7 @@ class TestAPIBasedExtensionService:
"requestor_instance": mock_requestor_instance,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -61,7 +62,7 @@ class TestAPIBasedExtensionService:
return account, tenant
def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful saving of API-based extension.
"""
@ -90,15 +91,16 @@ class TestAPIBasedExtensionService:
assert saved_extension.created_at is not None
# Verify extension was saved to database
from extensions.ext_database import db
db.session.refresh(saved_extension)
db_session_with_containers.refresh(saved_extension)
assert saved_extension.id is not None
# Verify ping connection was called
mock_external_service_dependencies["requestor_instance"].request.assert_called_once()
def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_validation_errors(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test validation errors when saving extension with invalid data.
"""
@ -132,7 +134,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_all_by_tenant_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of all extensions by tenant ID.
"""
@ -169,7 +173,7 @@ class TestAPIBasedExtensionService:
# Verify descending order (newer first)
assert extension.created_at <= extension_list[i - 1].created_at
def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of extension by tenant ID and extension ID.
"""
@ -200,7 +204,9 @@ class TestAPIBasedExtensionService:
assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted
assert retrieved_extension.created_at is not None
def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_with_tenant_id_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of extension when extension is not found.
"""
@ -214,7 +220,7 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful deletion of extension.
"""
@ -238,12 +244,15 @@ class TestAPIBasedExtensionService:
APIBasedExtensionService.delete(created_extension)
# Verify extension was deleted
from extensions.ext_database import db
deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
deleted_extension = (
db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first()
)
assert deleted_extension is None
def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test validation error when saving extension with duplicate name.
"""
@ -272,7 +281,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(extension_data2)
def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_update_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful update of existing extension.
"""
@ -329,7 +340,9 @@ class TestAPIBasedExtensionService:
assert retrieved_extension.api_endpoint == new_endpoint
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_connection_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test connection error when saving extension with invalid endpoint.
"""
@ -356,7 +369,7 @@ class TestAPIBasedExtensionService:
APIBasedExtensionService.save(extension_data)
def test_save_extension_invalid_api_key_length(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test validation error when saving extension with API key that is too short.
@ -378,7 +391,7 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test validation errors when saving extension with empty required fields.
"""
@ -412,7 +425,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_all_by_tenant_id_empty_list(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of extensions when no extensions exist for tenant.
"""
@ -428,7 +443,9 @@ class TestAPIBasedExtensionService:
assert len(extension_list) == 0
assert extension_list == []
def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_invalid_ping_response(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test validation error when ping response is invalid.
"""
@ -452,7 +469,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
APIBasedExtensionService.save(extension_data)
def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_extension_missing_ping_result(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test validation error when ping response is missing result field.
"""
@ -476,7 +495,9 @@ class TestAPIBasedExtensionService:
with pytest.raises(ValueError, match="{'status': 'ok'}"):
APIBasedExtensionService.save(extension_data)
def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_with_tenant_id_wrong_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of extension when tenant ID doesn't match.
"""

View File

@ -3,6 +3,7 @@ from unittest.mock import ANY, MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models.model import EndUser
@ -118,7 +119,9 @@ class TestAppGenerateService:
"global_dify_config": mock_global_dify_config,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"):
def _create_test_app_and_account(
self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat"
):
"""
Helper method to create a test app and account for testing.
@ -169,7 +172,7 @@ class TestAppGenerateService:
return app, account
def _create_test_workflow(self, db_session_with_containers, app):
def _create_test_workflow(self, db_session_with_containers: Session, app):
"""
Helper method to create a test workflow for testing.
@ -191,14 +194,14 @@ class TestAppGenerateService:
status="published",
)
from extensions.ext_database import db
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
return workflow
def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_completion_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful generation for completion mode app.
"""
@ -226,7 +229,7 @@ class TestAppGenerateService:
mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once()
def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful generation for chat mode app.
"""
@ -250,7 +253,9 @@ class TestAppGenerateService:
mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_agent_chat_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful generation for agent chat mode app.
"""
@ -274,7 +279,9 @@ class TestAppGenerateService:
mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once()
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_advanced_chat_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful generation for advanced chat mode app.
"""
@ -300,7 +307,9 @@ class TestAppGenerateService:
"advanced_chat_generator"
].return_value.convert_to_event_stream.assert_called_once()
def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_workflow_mode_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful generation for workflow mode app.
"""
@ -324,7 +333,9 @@ class TestAppGenerateService:
mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once()
mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_specific_workflow_id(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with a specific workflow ID.
"""
@ -355,7 +366,9 @@ class TestAppGenerateService:
"workflow_service"
].return_value.get_published_workflow_by_id.assert_called_once()
def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_debugger_invoke_from(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with debugger invoke from.
"""
@ -378,7 +391,9 @@ class TestAppGenerateService:
# Verify draft workflow was fetched for debugger
mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once()
def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_non_streaming_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with non-streaming mode.
"""
@ -401,7 +416,7 @@ class TestAppGenerateService:
# Verify rate limit exit was called for non-streaming mode
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test generation with EndUser instead of Account.
"""
@ -421,10 +436,8 @@ class TestAppGenerateService:
session_id=fake.uuid4(),
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
# Setup test arguments
args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"}
@ -438,7 +451,7 @@ class TestAppGenerateService:
assert result == ["test_response"]
def test_generate_with_billing_enabled_sandbox_plan(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with billing enabled and sandbox plan.
@ -466,7 +479,9 @@ class TestAppGenerateService:
# Verify billing service was called to consume quota
mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once()
def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_invalid_app_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with invalid app mode.
"""
@ -491,7 +506,7 @@ class TestAppGenerateService:
assert "Invalid app mode" in str(exc_info.value)
def test_generate_with_workflow_id_format_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with invalid workflow ID format.
@ -518,7 +533,7 @@ class TestAppGenerateService:
assert "Invalid workflow_id format" in str(exc_info.value)
def test_generate_with_workflow_not_found_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation when workflow is not found.
@ -552,7 +567,7 @@ class TestAppGenerateService:
assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value)
def test_generate_with_workflow_not_initialized_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation when workflow is not initialized for debugger.
@ -578,7 +593,7 @@ class TestAppGenerateService:
assert "Workflow not initialized" in str(exc_info.value)
def test_generate_with_workflow_not_published_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation when workflow is not published for non-debugger.
@ -604,7 +619,7 @@ class TestAppGenerateService:
assert "Workflow not published" in str(exc_info.value)
def test_generate_single_iteration_advanced_chat_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful single iteration generation for advanced chat mode.
@ -631,7 +646,7 @@ class TestAppGenerateService:
].return_value.single_iteration_generate.assert_called_once()
def test_generate_single_iteration_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful single iteration generation for workflow mode.
@ -658,7 +673,7 @@ class TestAppGenerateService:
].return_value.single_iteration_generate.assert_called_once()
def test_generate_single_iteration_invalid_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test single iteration generation with invalid app mode.
@ -681,7 +696,7 @@ class TestAppGenerateService:
assert "Invalid app mode" in str(exc_info.value)
def test_generate_single_loop_advanced_chat_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful single loop generation for advanced chat mode.
@ -708,7 +723,7 @@ class TestAppGenerateService:
].return_value.single_loop_generate.assert_called_once()
def test_generate_single_loop_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful single loop generation for workflow mode.
@ -732,7 +747,9 @@ class TestAppGenerateService:
# Verify workflow generator was called
mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once()
def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_single_loop_invalid_mode(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test single loop generation with invalid app mode.
"""
@ -753,7 +770,9 @@ class TestAppGenerateService:
# Verify error message
assert "Invalid app mode" in str(exc_info.value)
def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_more_like_this_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful more like this generation.
"""
@ -778,7 +797,7 @@ class TestAppGenerateService:
].return_value.generate_more_like_this.assert_called_once()
def test_generate_more_like_this_with_end_user(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test more like this generation with EndUser.
@ -799,10 +818,8 @@ class TestAppGenerateService:
session_id=fake.uuid4(),
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
message_id = fake.uuid4()
@ -815,7 +832,7 @@ class TestAppGenerateService:
assert result == ["more_like_this_response"]
def test_get_max_active_requests_with_app_limit(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting max active requests with app-specific limit.
@ -835,7 +852,7 @@ class TestAppGenerateService:
assert result == 10
def test_get_max_active_requests_with_config_limit(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting max active requests with config limit being smaller.
@ -856,7 +873,7 @@ class TestAppGenerateService:
assert result <= 100
def test_get_max_active_requests_with_zero_limits(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting max active requests with zero limits (infinite).
@ -875,7 +892,9 @@ class TestAppGenerateService:
# Verify the result (should return config limit when app limit is 0)
assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS
def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_exception_cleanup(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that rate limit exit is called when an exception occurs.
"""
@ -904,7 +923,9 @@ class TestAppGenerateService:
# Verify rate limit exit was called for cleanup
mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once()
def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_agent_mode_detection(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with agent mode detection based on app configuration.
"""
@ -932,7 +953,7 @@ class TestAppGenerateService:
mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once()
def test_generate_with_different_invoke_from_values(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test generation with different invoke from values.
@ -962,7 +983,7 @@ class TestAppGenerateService:
# Verify the result
assert result == ["test_response"]
def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies):
def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test generation with complex arguments including files and external trace ID.
"""

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from constants.model_template import default_app_templates
from models import Account
@ -44,7 +45,7 @@ class TestAppService:
"account_feature_service": mock_account_feature_service,
}
def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app creation with basic parameters.
"""
@ -98,7 +99,9 @@ class TestAppService:
assert app.is_public is False
assert app.is_universal is False
def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_app_with_different_modes(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app creation with different app modes.
"""
@ -141,7 +144,7 @@ class TestAppService:
assert app.tenant_id == tenant.id
assert app.created_by == account.id
def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app retrieval.
"""
@ -189,7 +192,7 @@ class TestAppService:
assert retrieved_app.tenant_id == created_app.tenant_id
assert retrieved_app.created_by == created_app.created_by
def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful paginated app list retrieval.
"""
@ -243,7 +246,9 @@ class TestAppService:
assert app.tenant_id == tenant.id
assert app.mode == "chat"
def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_paginate_apps_with_filters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test paginated app list with various filters.
"""
@ -316,7 +321,9 @@ class TestAppService:
my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args)
assert len(my_apps.items) == 1
def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_paginate_apps_with_tag_filters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test paginated app list with tag filters.
"""
@ -386,7 +393,7 @@ class TestAppService:
# Should return None when no apps match tag filter
assert paginated_apps is None
def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app update with all fields.
"""
@ -455,7 +462,7 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app name update.
"""
@ -508,7 +515,7 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app icon update.
"""
@ -565,7 +572,9 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_app_site_status_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful app site status update.
"""
@ -623,7 +632,9 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_app_api_status_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful app API status update.
"""
@ -681,7 +692,9 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_app_site_status_no_change(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app site status update when status doesn't change.
"""
@ -732,7 +745,7 @@ class TestAppService:
assert updated_app.tenant_id == app.tenant_id
assert updated_app.created_by == app.created_by
def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app deletion.
"""
@ -778,12 +791,13 @@ class TestAppService:
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
# Verify app was deleted from database
from extensions.ext_database import db
deleted_app = db.session.query(App).filter_by(id=app_id).first()
deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first()
assert deleted_app is None
def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_app_with_related_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app deletion with related data cleanup.
"""
@ -839,12 +853,11 @@ class TestAppService:
mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id)
# Verify app was deleted from database
from extensions.ext_database import db
deleted_app = db.session.query(App).filter_by(id=app_id).first()
deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first()
assert deleted_app is None
def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app metadata retrieval.
"""
@ -883,7 +896,7 @@ class TestAppService:
assert "tool_icons" in app_meta
# Note: get_app_meta currently only returns tool_icons
def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app code retrieval by app ID.
"""
@ -923,7 +936,7 @@ class TestAppService:
assert app_code is not None
assert len(app_code) > 0
def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful app ID retrieval by app code.
"""
@ -963,10 +976,9 @@ class TestAppService:
site.status = "normal"
site.default_language = "en-US"
site.customize_token_strategy = "uuid"
from extensions.ext_database import db
db.session.add(site)
db.session.commit()
db_session_with_containers.add(site)
db_session_with_containers.commit()
# Get app ID by code
app_id = AppService.get_app_id_by_code(site.code)
@ -974,7 +986,7 @@ class TestAppService:
# Verify app ID was retrieved correctly
assert app_id == app.id
def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test app creation with invalid mode.
"""
@ -1010,7 +1022,7 @@ class TestAppService:
app_service.create_app(tenant.id, app_args, account)
def test_get_apps_with_special_characters_in_name(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
r"""
Test app retrieval with special characters in name search to verify SQL injection prevention.

View File

@ -9,10 +9,10 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline
from services.dataset_service import DatasetService
@ -25,7 +25,9 @@ class DatasetServiceIntegrationDataFactory:
"""Factory for creating real database entities used by integration tests."""
@staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
def create_account_with_tenant(
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER
) -> tuple[Account, Tenant]:
"""Create an account and tenant, then bind the account as current tenant member."""
account = Account(
email=f"{uuid4()}@example.com",
@ -34,8 +36,8 @@ class DatasetServiceIntegrationDataFactory:
status="active",
)
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add_all([account, tenant])
db.session.flush()
db_session_with_containers.add_all([account, tenant])
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
@ -43,8 +45,8 @@ class DatasetServiceIntegrationDataFactory:
role=role,
current=True,
)
db.session.add(join)
db.session.flush()
db_session_with_containers.add(join)
db_session_with_containers.flush()
# Keep tenant context on the in-memory user without opening a separate session.
account.role = role
@ -53,6 +55,7 @@ class DatasetServiceIntegrationDataFactory:
@staticmethod
def create_dataset(
db_session_with_containers: Session,
tenant_id: str,
created_by: str,
name: str = "Test Dataset",
@ -82,12 +85,14 @@ class DatasetServiceIntegrationDataFactory:
collection_binding_id=collection_binding_id,
chunk_structure=chunk_structure,
)
db.session.add(dataset)
db.session.flush()
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
return dataset
@staticmethod
def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document:
def create_document(
db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt"
) -> Document:
"""Create a document row belonging to the given dataset."""
document = Document(
tenant_id=dataset.tenant_id,
@ -102,8 +107,8 @@ class DatasetServiceIntegrationDataFactory:
indexing_status="completed",
doc_form="text_model",
)
db.session.add(document)
db.session.flush()
db_session_with_containers.add(document)
db_session_with_containers.flush()
return document
@staticmethod
@ -118,10 +123,10 @@ class DatasetServiceIntegrationDataFactory:
class TestDatasetServiceCreateDataset:
"""Integration coverage for DatasetService.create_empty_dataset."""
def test_create_internal_dataset_basic_success(self, db_session_with_containers):
def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session):
"""Create a basic internal dataset with minimal configuration."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
# Act
result = DatasetService.create_empty_dataset(
@ -133,17 +138,17 @@ class TestDatasetServiceCreateDataset:
)
# Assert
created_dataset = db.session.get(Dataset, result.id)
created_dataset = db_session_with_containers.get(Dataset, result.id)
assert created_dataset is not None
assert created_dataset.provider == "vendor"
assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME
assert created_dataset.embedding_model_provider is None
assert created_dataset.embedding_model is None
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers):
def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session):
"""Create an internal dataset with economy indexing and no embedding model."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
# Act
result = DatasetService.create_empty_dataset(
@ -155,15 +160,15 @@ class TestDatasetServiceCreateDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.indexing_technique == "economy"
assert result.embedding_model_provider is None
assert result.embedding_model is None
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers):
def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session):
"""Create a high-quality dataset and persist embedding model settings."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
# Act
@ -179,7 +184,7 @@ class TestDatasetServiceCreateDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model_name
@ -188,11 +193,12 @@ class TestDatasetServiceCreateDataset:
model_type=ModelType.TEXT_EMBEDDING,
)
def test_create_dataset_duplicate_name_error(self, db_session_with_containers):
def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session):
"""Raise duplicate-name error when the same tenant already has the name."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name="Duplicate Dataset",
@ -209,10 +215,10 @@ class TestDatasetServiceCreateDataset:
account=account,
)
def test_create_external_dataset_success(self, db_session_with_containers):
def test_create_external_dataset_success(self, db_session_with_containers: Session):
"""Create an external dataset and persist external knowledge binding."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
external_knowledge_api_id = str(uuid4())
external_knowledge_id = "knowledge-123"
@ -231,16 +237,16 @@ class TestDatasetServiceCreateDataset:
)
# Assert
binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first()
assert result.provider == "external"
assert binding is not None
assert binding.external_knowledge_id == external_knowledge_id
assert binding.external_knowledge_api_id == external_knowledge_api_id
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers):
def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session):
"""Create a high-quality dataset with retrieval/reranking settings."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model()
retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH,
@ -271,14 +277,16 @@ class TestDatasetServiceCreateDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.retrieval_model == retrieval_model.model_dump()
mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0")
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(self, db_session_with_containers):
def test_create_internal_dataset_with_high_quality_indexing_custom_embedding(
self, db_session_with_containers: Session
):
"""Create high-quality dataset with explicitly configured embedding model."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
embedding_provider = "openai"
embedding_model_name = "text-embedding-3-small"
embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model(
@ -303,7 +311,7 @@ class TestDatasetServiceCreateDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_provider
assert result.embedding_model == embedding_model_name
@ -315,10 +323,10 @@ class TestDatasetServiceCreateDataset:
model=embedding_model_name,
)
def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers):
def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session):
"""Persist retrieval model settings when creating an internal dataset."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
retrieval_model = RetrievalModel(
search_method=RetrievalMethod.SEMANTIC_SEARCH,
reranking_enable=False,
@ -338,13 +346,13 @@ class TestDatasetServiceCreateDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.retrieval_model == retrieval_model.model_dump()
def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers):
def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session):
"""Persist canonical custom permission when creating an internal dataset."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
# Act
result = DatasetService.create_empty_dataset(
@ -357,13 +365,13 @@ class TestDatasetServiceCreateDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.permission == DatasetPermissionEnum.ALL_TEAM
def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers):
def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session):
"""Raise error when external API template does not exist."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
external_knowledge_api_id = str(uuid4())
# Act / Assert
@ -381,10 +389,10 @@ class TestDatasetServiceCreateDataset:
external_knowledge_id="knowledge-123",
)
def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session):
"""Raise error when external knowledge id is missing for external dataset creation."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
external_knowledge_api_id = str(uuid4())
# Act / Assert
@ -406,10 +414,10 @@ class TestDatasetServiceCreateDataset:
class TestDatasetServiceCreateRagPipelineDataset:
"""Integration coverage for DatasetService.create_empty_rag_pipeline_dataset."""
def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session):
"""Create rag-pipeline dataset and pipeline rows when a name is provided."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name="RAG Pipeline Dataset",
@ -425,8 +433,8 @@ class TestDatasetServiceCreateRagPipelineDataset:
)
# Assert
created_dataset = db.session.get(Dataset, result.id)
created_pipeline = db.session.get(Pipeline, result.pipeline_id)
created_dataset = db_session_with_containers.get(Dataset, result.id)
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert created_dataset is not None
assert created_dataset.name == entity.name
assert created_dataset.runtime_mode == "rag_pipeline"
@ -436,10 +444,10 @@ class TestDatasetServiceCreateRagPipelineDataset:
assert created_pipeline.name == entity.name
assert created_pipeline.created_by == account.id
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session):
"""Create rag-pipeline dataset with generated incremental name when input name is empty."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
generated_name = "Untitled 1"
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
@ -460,25 +468,26 @@ class TestDatasetServiceCreateRagPipelineDataset:
)
# Assert
db.session.refresh(result)
created_pipeline = db.session.get(Pipeline, result.pipeline_id)
db_session_with_containers.refresh(result)
created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id)
assert result.name == generated_name
assert created_pipeline is not None
assert created_pipeline.name == generated_name
mock_generate_name.assert_called_once()
def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session):
"""Raise duplicate-name error when rag-pipeline dataset name already exists."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
duplicate_name = "Duplicate RAG Dataset"
DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name=duplicate_name,
indexing_technique=None,
)
db.session.commit()
db_session_with_containers.commit()
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name=duplicate_name,
@ -496,10 +505,10 @@ class TestDatasetServiceCreateRagPipelineDataset:
tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity
)
def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session):
"""Persist canonical custom permission for rag-pipeline dataset creation."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji")
entity = RagPipelineDatasetCreateEntity(
name="Custom Permission RAG Dataset",
@ -515,13 +524,13 @@ class TestDatasetServiceCreateRagPipelineDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.permission == DatasetPermissionEnum.ALL_TEAM
def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers):
def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session):
"""Persist icon metadata when creating rag-pipeline dataset."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
icon_info = IconInfo(
icon="📚",
icon_background="#E8F5E9",
@ -542,23 +551,25 @@ class TestDatasetServiceCreateRagPipelineDataset:
)
# Assert
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.icon_info == icon_info.model_dump()
class TestDatasetServiceUpdateAndDeleteDataset:
"""Integration coverage for SQL-backed update and delete behavior."""
def test_update_dataset_duplicate_name_error(self, db_session_with_containers):
def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session):
"""Reject update when target name already exists within the same tenant."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
source_dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name="Source Dataset",
)
DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name="Existing Dataset",
@ -568,17 +579,20 @@ class TestDatasetServiceUpdateAndDeleteDataset:
with pytest.raises(ValueError, match="Dataset name already exists"):
DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account)
def test_delete_dataset_with_documents_success(self, db_session_with_containers):
def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session):
"""Delete a dataset that already has documents."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
chunk_structure="text_model",
)
DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id)
DatasetServiceIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, created_by=account.id
)
# Act
with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal:
@ -586,14 +600,15 @@ class TestDatasetServiceUpdateAndDeleteDataset:
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
assert db_session_with_containers.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset)
def test_delete_empty_dataset_success(self, db_session_with_containers):
def test_delete_empty_dataset_success(self, db_session_with_containers: Session):
"""Delete a dataset that has no documents and no indexing technique."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique=None,
@ -606,14 +621,15 @@ class TestDatasetServiceUpdateAndDeleteDataset:
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
assert db_session_with_containers.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset)
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers):
def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session):
"""Delete dataset when indexing_technique is None but doc_form path still exists."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique=None,
@ -626,17 +642,17 @@ class TestDatasetServiceUpdateAndDeleteDataset:
# Assert
assert result is True
assert db.session.get(Dataset, dataset.id) is None
assert db_session_with_containers.get(Dataset, dataset.id) is None
dataset_deleted_signal.send.assert_called_once_with(dataset)
class TestDatasetServiceRetrievalConfiguration:
"""Integration coverage for retrieval configuration persistence."""
def test_get_dataset_retrieval_configuration(self, db_session_with_containers):
def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session):
"""Return retrieval configuration that is persisted in SQL."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
retrieval_model = {
"search_method": "semantic_search",
"top_k": 5,
@ -644,6 +660,7 @@ class TestDatasetServiceRetrievalConfiguration:
"reranking_enable": True,
}
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
retrieval_model=retrieval_model,
@ -658,11 +675,12 @@ class TestDatasetServiceRetrievalConfiguration:
assert result.retrieval_model["search_method"] == "semantic_search"
assert result.retrieval_model["top_k"] == 5
def test_update_dataset_retrieval_configuration(self, db_session_with_containers):
def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session):
"""Persist retrieval configuration updates through DatasetService.update_dataset."""
# Arrange
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant()
account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetServiceIntegrationDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
indexing_technique="high_quality",
@ -684,6 +702,6 @@ class TestDatasetServiceRetrievalConfiguration:
result = DatasetService.update_dataset(dataset.id, update_data, account)
# Assert
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert result.id == dataset.id
assert dataset.retrieval_model == update_data["retrieval_model"]

View File

@ -11,8 +11,8 @@ from unittest.mock import call, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.dataset import Dataset, Document
from services.dataset_service import DocumentService
from services.errors.document import DocumentIndexingError
@ -32,6 +32,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
@staticmethod
def create_dataset(
db_session_with_containers: Session,
dataset_id: str | None = None,
tenant_id: str | None = None,
name: str = "Test Dataset",
@ -47,12 +48,13 @@ class DocumentBatchUpdateIntegrationDataFactory:
if dataset_id:
dataset.id = dataset_id
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_document(
db_session_with_containers: Session,
dataset: Dataset,
document_id: str | None = None,
name: str = "test_document.pdf",
@ -89,13 +91,14 @@ class DocumentBatchUpdateIntegrationDataFactory:
for key, value in kwargs.items():
setattr(document, key, value)
db.session.add(document)
db_session_with_containers.add(document)
if commit:
db.session.commit()
db_session_with_containers.commit()
return document
@staticmethod
def create_multiple_documents(
db_session_with_containers: Session,
dataset: Dataset,
document_ids: list[str],
enabled: bool = True,
@ -106,6 +109,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
documents: list[Document] = []
for index, doc_id in enumerate(document_ids, start=1):
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
document_id=doc_id,
name=f"document_{doc_id}.pdf",
@ -116,7 +120,7 @@ class DocumentBatchUpdateIntegrationDataFactory:
commit=False,
)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
return documents
@staticmethod
@ -173,13 +177,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
assert document.archived_at is None
assert document.archived_by is None
def test_batch_update_enable_documents_success(self, db_session_with_containers, patched_dependencies):
def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Enable disabled documents and trigger indexing side effects."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document_ids = [str(uuid4()), str(uuid4())]
disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
db_session_with_containers,
dataset=dataset,
document_ids=document_ids,
enabled=False,
@ -192,7 +197,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Assert
for document in disabled_docs:
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_enabled(document, FIXED_TIME)
expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
@ -203,13 +208,15 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls)
def test_batch_update_enable_already_enabled_document_skipped(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Skip enable operation for already-enabled documents."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True)
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True
)
# Act
DocumentService.batch_update_document_status(
@ -220,18 +227,19 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.enabled is True
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_disable_documents_success(self, db_session_with_containers, patched_dependencies):
def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Disable completed documents and trigger remove-index tasks."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document_ids = [str(uuid4()), str(uuid4())]
enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
db_session_with_containers,
dataset=dataset,
document_ids=document_ids,
enabled=True,
@ -248,7 +256,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Assert
for document in enabled_docs:
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_disabled(document, user.id, FIXED_TIME)
expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids]
@ -259,13 +267,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls)
def test_batch_update_disable_already_disabled_document_skipped(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Skip disable operation for already-disabled documents."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
enabled=False,
indexing_status="completed",
@ -281,17 +290,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(disabled_doc)
db_session_with_containers.refresh(disabled_doc)
assert disabled_doc.enabled is False
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_disable_non_completed_document_error(self, db_session_with_containers, patched_dependencies):
def test_batch_update_disable_non_completed_document_error(
self, db_session_with_containers: Session, patched_dependencies
):
"""Raise error when disabling a non-completed document."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
enabled=True,
indexing_status="indexing",
@ -307,13 +319,13 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
user=user,
)
def test_batch_update_archive_documents_success(self, db_session_with_containers, patched_dependencies):
def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Archive enabled documents and trigger remove-index task."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=False
db_session_with_containers, dataset=dataset, enabled=True, archived=False
)
# Act
@ -325,21 +337,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_archived(document, user.id, FIXED_TIME)
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
patched_dependencies["remove_task"].delay.assert_called_once_with(document.id)
def test_batch_update_archive_already_archived_document_skipped(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Skip archive operation for already-archived documents."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=True
db_session_with_containers, dataset=dataset, enabled=True, archived=True
)
# Act
@ -351,20 +363,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.archived is True
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_archive_disabled_document_no_index_removal(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Archive disabled document without index-removal side effects."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=False, archived=False
db_session_with_containers, dataset=dataset, enabled=False, archived=False
)
# Act
@ -376,18 +388,18 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_archived(document, user.id, FIXED_TIME)
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["remove_task"].delay.assert_not_called()
def test_batch_update_unarchive_documents_success(self, db_session_with_containers, patched_dependencies):
def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies):
"""Unarchive enabled documents and trigger add-index task."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=True
db_session_with_containers, dataset=dataset, enabled=True, archived=True
)
# Act
@ -399,7 +411,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_unarchived(document)
assert document.updated_at == FIXED_TIME
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
@ -407,14 +419,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["add_task"].delay.assert_called_once_with(document.id)
def test_batch_update_unarchive_already_unarchived_document_skipped(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Skip unarchive operation for already-unarchived documents."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=True, archived=False
db_session_with_containers, dataset=dataset, enabled=True, archived=False
)
# Act
@ -426,20 +438,20 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.archived is False
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_unarchive_disabled_document_no_index_addition(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Unarchive disabled document without index-add side effects."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
dataset=dataset, enabled=False, archived=True
db_session_with_containers, dataset=dataset, enabled=False, archived=True
)
# Act
@ -451,20 +463,21 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_unarchived(document)
assert document.updated_at == FIXED_TIME
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_document_indexing_error_redis_cache_hit(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Raise DocumentIndexingError when redis indicates active indexing."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
name="test_document.pdf",
enabled=True,
@ -483,12 +496,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
assert "test_document.pdf" in str(exc_info.value)
patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing")
def test_batch_update_async_task_error_handling(self, db_session_with_containers, patched_dependencies):
def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies):
"""Persist DB update, then propagate async task error."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
document = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=False
)
patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error")
# Act / Assert
@ -500,14 +515,14 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
user=user,
)
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_enabled(document, FIXED_TIME)
patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1)
def test_batch_update_empty_document_list(self, db_session_with_containers, patched_dependencies):
def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies):
"""Return early when document_ids is empty."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
# Act
@ -520,10 +535,10 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["redis_client"].get.assert_not_called()
patched_dependencies["redis_client"].setex.assert_not_called()
def test_batch_update_document_not_found_skipped(self, db_session_with_containers, patched_dependencies):
def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies):
"""Skip IDs that do not map to existing dataset documents."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
missing_document_id = str(uuid4())
@ -540,18 +555,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["redis_client"].setex.assert_not_called()
patched_dependencies["add_task"].delay.assert_not_called()
def test_batch_update_mixed_document_states_and_actions(self, db_session_with_containers, patched_dependencies):
def test_batch_update_mixed_document_states_and_actions(
self, db_session_with_containers: Session, patched_dependencies
):
"""Process only the applicable document in a mixed-state enable batch."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=False
)
enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
enabled=True,
position=2,
)
archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
enabled=True,
archived=True,
@ -568,9 +589,9 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(disabled_doc)
db.session.refresh(enabled_doc)
db.session.refresh(archived_doc)
db_session_with_containers.refresh(disabled_doc)
db_session_with_containers.refresh(enabled_doc)
db_session_with_containers.refresh(archived_doc)
self._assert_document_enabled(disabled_doc, FIXED_TIME)
assert enabled_doc.enabled is True
assert archived_doc.enabled is True
@ -582,13 +603,16 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id)
def test_batch_update_large_document_list_performance(self, db_session_with_containers, patched_dependencies):
def test_batch_update_large_document_list_performance(
self, db_session_with_containers: Session, patched_dependencies
):
"""Handle large document lists with consistent updates and side effects."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
document_ids = [str(uuid4()) for _ in range(100)]
documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents(
db_session_with_containers,
dataset=dataset,
document_ids=document_ids,
enabled=False,
@ -604,7 +628,7 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
# Assert
for document in documents:
db.session.refresh(document)
db_session_with_containers.refresh(document)
self._assert_document_enabled(document, FIXED_TIME)
assert patched_dependencies["redis_client"].setex.call_count == len(document_ids)
@ -616,17 +640,26 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls)
def test_batch_update_mixed_document_states_complex_scenario(
self, db_session_with_containers, patched_dependencies
self, db_session_with_containers: Session, patched_dependencies
):
"""Process a complex mixed-state batch and update only eligible records."""
# Arrange
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset()
dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers)
user = DocumentBatchUpdateIntegrationDataFactory.create_user()
doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=False)
doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=2)
doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=3)
doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(dataset=dataset, enabled=True, position=4)
doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=False
)
doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True, position=2
)
doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True, position=3
)
doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers, dataset=dataset, enabled=True, position=4
)
doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document(
db_session_with_containers,
dataset=dataset,
enabled=True,
archived=True,
@ -645,11 +678,11 @@ class TestDatasetServiceBatchUpdateDocumentStatus:
)
# Assert
db.session.refresh(doc1)
db.session.refresh(doc2)
db.session.refresh(doc3)
db.session.refresh(doc4)
db.session.refresh(doc5)
db_session_with_containers.refresh(doc1)
db_session_with_containers.refresh(doc2)
db_session_with_containers.refresh(doc3)
db_session_with_containers.refresh(doc4)
db_session_with_containers.refresh(doc5)
self._assert_document_enabled(doc1, FIXED_TIME)
assert doc2.enabled is True
assert doc3.enabled is True

View File

@ -10,7 +10,8 @@ Tests the retrieval of document segments with pagination and filtering:
from uuid import uuid4
from extensions.ext_database import db
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
from services.dataset_service import SegmentService
@ -23,6 +24,7 @@ class SegmentServiceTestDataFactory:
@staticmethod
def create_account_with_tenant(
db_session_with_containers: Session,
role: TenantAccountRole = TenantAccountRole.OWNER,
tenant: Tenant | None = None,
) -> tuple[Account, Tenant]:
@ -33,13 +35,13 @@ class SegmentServiceTestDataFactory:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
if tenant is None:
tenant = Tenant(name=f"tenant-{uuid4()}", status="normal")
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
@ -47,14 +49,14 @@ class SegmentServiceTestDataFactory:
role=role,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_dataset(tenant_id: str, created_by: str) -> Dataset:
def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset:
"""Create a real dataset."""
dataset = Dataset(
tenant_id=tenant_id,
@ -67,12 +69,14 @@ class SegmentServiceTestDataFactory:
provider="vendor",
retrieval_model={"top_k": 2},
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document:
def create_document(
db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str
) -> Document:
"""Create a real document."""
document = Document(
tenant_id=tenant_id,
@ -84,12 +88,13 @@ class SegmentServiceTestDataFactory:
created_from="api",
created_by=created_by,
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@staticmethod
def create_segment(
db_session_with_containers: Session,
tenant_id: str,
dataset_id: str,
document_id: str,
@ -112,8 +117,8 @@ class SegmentServiceTestDataFactory:
tokens=tokens,
created_by=created_by,
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
@ -130,7 +135,7 @@ class TestSegmentServiceGetSegments:
- Combined filters
"""
def test_get_segments_basic_pagination(self, db_session_with_containers):
def test_get_segments_basic_pagination(self, db_session_with_containers: Session):
"""
Test basic pagination functionality.
@ -140,11 +145,14 @@ class TestSegmentServiceGetSegments:
- Returns segments and total count
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
segment1 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -153,6 +161,7 @@ class TestSegmentServiceGetSegments:
content="First segment",
)
segment2 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -170,7 +179,7 @@ class TestSegmentServiceGetSegments:
assert items[0].id == segment1.id
assert items[1].id == segment2.id
def test_get_segments_with_status_filter(self, db_session_with_containers):
def test_get_segments_with_status_filter(self, db_session_with_containers: Session):
"""
Test filtering by status list.
@ -179,11 +188,14 @@ class TestSegmentServiceGetSegments:
- Only segments with matching status are returned
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -192,6 +204,7 @@ class TestSegmentServiceGetSegments:
status="completed",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -200,6 +213,7 @@ class TestSegmentServiceGetSegments:
status="indexing",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -219,7 +233,7 @@ class TestSegmentServiceGetSegments:
statuses = {item.status for item in items}
assert statuses == {"completed", "indexing"}
def test_get_segments_with_empty_status_list(self, db_session_with_containers):
def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session):
"""
Test with empty status list.
@ -228,11 +242,14 @@ class TestSegmentServiceGetSegments:
- No status filter is applied to avoid WHERE false condition
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -241,6 +258,7 @@ class TestSegmentServiceGetSegments:
status="completed",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -256,7 +274,7 @@ class TestSegmentServiceGetSegments:
assert len(items) == 2
assert total == 2
def test_get_segments_with_keyword_search(self, db_session_with_containers):
def test_get_segments_with_keyword_search(self, db_session_with_containers: Session):
"""
Test keyword search functionality.
@ -265,11 +283,14 @@ class TestSegmentServiceGetSegments:
- Search pattern includes wildcards (%keyword%)
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -278,6 +299,7 @@ class TestSegmentServiceGetSegments:
content="This contains search term in the middle",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -294,7 +316,7 @@ class TestSegmentServiceGetSegments:
assert total == 1
assert "search term" in items[0].content
def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers):
def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session):
"""
Test ordering by position and id.
@ -304,12 +326,15 @@ class TestSegmentServiceGetSegments:
- This prevents duplicate data across pages when positions are not unique
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
# Create segments with different positions
seg_pos2 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -318,6 +343,7 @@ class TestSegmentServiceGetSegments:
content="Position 2",
)
seg_pos1 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -326,6 +352,7 @@ class TestSegmentServiceGetSegments:
content="Position 1",
)
seg_pos3 = SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -344,7 +371,7 @@ class TestSegmentServiceGetSegments:
assert items[1].id == seg_pos2.id
assert items[2].id == seg_pos3.id
def test_get_segments_empty_results(self, db_session_with_containers):
def test_get_segments_empty_results(self, db_session_with_containers: Session):
"""
Test when no segments match the criteria.
@ -353,7 +380,7 @@ class TestSegmentServiceGetSegments:
- Total count is 0
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
non_existent_doc_id = str(uuid4())
# Act
@ -363,7 +390,7 @@ class TestSegmentServiceGetSegments:
assert items == []
assert total == 0
def test_get_segments_combined_filters(self, db_session_with_containers):
def test_get_segments_combined_filters(self, db_session_with_containers: Session):
"""
Test with multiple filters combined.
@ -372,12 +399,15 @@ class TestSegmentServiceGetSegments:
- Status list and keyword search both applied
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
# Create segments with various statuses and content
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -387,6 +417,7 @@ class TestSegmentServiceGetSegments:
content="This is important information",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -396,6 +427,7 @@ class TestSegmentServiceGetSegments:
content="This is also important",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -421,7 +453,7 @@ class TestSegmentServiceGetSegments:
assert items[0].status == "completed"
assert "important" in items[0].content
def test_get_segments_with_none_status_list(self, db_session_with_containers):
def test_get_segments_with_none_status_list(self, db_session_with_containers: Session):
"""
Test with None status list.
@ -430,11 +462,14 @@ class TestSegmentServiceGetSegments:
- No status filter is applied
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -443,6 +478,7 @@ class TestSegmentServiceGetSegments:
status="completed",
)
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
@ -462,7 +498,7 @@ class TestSegmentServiceGetSegments:
assert len(items) == 2
assert total == 2
def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers):
def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session):
"""
Test that max_per_page is correctly set to 100.
@ -471,13 +507,16 @@ class TestSegmentServiceGetSegments:
- This prevents excessive page sizes
"""
# Arrange
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant()
dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id)
owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id)
document = SegmentServiceTestDataFactory.create_document(
db_session_with_containers, tenant.id, dataset.id, owner.id
)
# Create 105 segments to exceed max_per_page of 100
for i in range(105):
SegmentServiceTestDataFactory.create_segment(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,

View File

@ -13,7 +13,8 @@ This test suite covers:
import json
from uuid import uuid4
from extensions.ext_database import db
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@ -31,7 +32,9 @@ class DatasetRetrievalTestDataFactory:
"""Factory class for creating database-backed test data for dataset retrieval integration tests."""
@staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]:
def create_account_with_tenant(
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL
) -> tuple[Account, Tenant]:
"""Create an account and tenant with the specified role."""
account = Account(
email=f"{uuid4()}@example.com",
@ -43,8 +46,8 @@ class DatasetRetrievalTestDataFactory:
name=f"tenant-{uuid4()}",
status="normal",
)
db.session.add_all([account, tenant])
db.session.flush()
db_session_with_containers.add_all([account, tenant])
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
@ -52,14 +55,16 @@ class DatasetRetrievalTestDataFactory:
role=role,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account:
def create_account_in_tenant(
db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER
) -> Account:
"""Create an account and add it to an existing tenant."""
account = Account(
email=f"{uuid4()}@example.com",
@ -67,8 +72,8 @@ class DatasetRetrievalTestDataFactory:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.flush()
db_session_with_containers.add(account)
db_session_with_containers.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
@ -76,14 +81,15 @@ class DatasetRetrievalTestDataFactory:
role=role,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account
@staticmethod
def create_dataset(
db_session_with_containers: Session,
tenant_id: str,
created_by: str,
name: str = "Test Dataset",
@ -101,12 +107,14 @@ class DatasetRetrievalTestDataFactory:
provider="vendor",
retrieval_model={"top_k": 2},
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission:
def create_dataset_permission(
db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str
) -> DatasetPermission:
"""Create a dataset permission."""
permission = DatasetPermission(
dataset_id=dataset_id,
@ -114,12 +122,14 @@ class DatasetRetrievalTestDataFactory:
account_id=account_id,
has_permission=True,
)
db.session.add(permission)
db.session.commit()
db_session_with_containers.add(permission)
db_session_with_containers.commit()
return permission
@staticmethod
def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule:
def create_process_rule(
db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict
) -> DatasetProcessRule:
"""Create a dataset process rule."""
process_rule = DatasetProcessRule(
dataset_id=dataset_id,
@ -127,12 +137,14 @@ class DatasetRetrievalTestDataFactory:
mode=mode,
rules=json.dumps(rules),
)
db.session.add(process_rule)
db.session.commit()
db_session_with_containers.add(process_rule)
db_session_with_containers.commit()
return process_rule
@staticmethod
def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery:
def create_dataset_query(
db_session_with_containers: Session, dataset_id: str, created_by: str, content: str
) -> DatasetQuery:
"""Create a dataset query."""
dataset_query = DatasetQuery(
dataset_id=dataset_id,
@ -142,23 +154,23 @@ class DatasetRetrievalTestDataFactory:
created_by_role="account",
created_by=created_by,
)
db.session.add(dataset_query)
db.session.commit()
db_session_with_containers.add(dataset_query)
db_session_with_containers.commit()
return dataset_query
@staticmethod
def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin:
def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin:
"""Create an app-dataset join."""
join = AppDatasetJoin(
app_id=str(uuid4()),
dataset_id=dataset_id,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
return join
@staticmethod
def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag:
def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag:
"""Create a knowledge tag and bind it to the target dataset."""
tag = Tag(
tenant_id=tenant_id,
@ -166,8 +178,8 @@ class DatasetRetrievalTestDataFactory:
name=f"tag-{uuid4()}",
created_by=created_by,
)
db.session.add(tag)
db.session.flush()
db_session_with_containers.add(tag)
db_session_with_containers.flush()
binding = TagBinding(
tenant_id=tenant_id,
@ -175,8 +187,8 @@ class DatasetRetrievalTestDataFactory:
target_id=target_id,
created_by=created_by,
)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(binding)
db_session_with_containers.commit()
return tag
@ -195,15 +207,16 @@ class TestDatasetServiceGetDatasets:
# ==================== Basic Retrieval Tests ====================
def test_get_datasets_basic_pagination(self, db_session_with_containers):
def test_get_datasets_basic_pagination(self, db_session_with_containers: Session):
"""Test basic pagination without user or filters."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1
per_page = 20
for i in range(5):
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name=f"Dataset {i}",
@ -217,21 +230,23 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 5
assert total == 5
def test_get_datasets_with_search(self, db_session_with_containers):
def test_get_datasets_with_search(self, db_session_with_containers: Session):
"""Test get_datasets with search keyword."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1
per_page = 20
search = "test"
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name="Test Dataset",
permission=DatasetPermissionEnum.ALL_TEAM,
)
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name="Another Dataset",
@ -245,26 +260,32 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1
assert total == 1
def test_get_datasets_with_tag_filtering(self, db_session_with_containers):
def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session):
"""Test get_datasets with tag_ids filtering."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1
per_page = 20
dataset_1 = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
permission=DatasetPermissionEnum.ALL_TEAM,
)
dataset_2 = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
permission=DatasetPermissionEnum.ALL_TEAM,
)
tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id)
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id)
tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(
db_session_with_containers, tenant.id, account.id, dataset_1.id
)
tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(
db_session_with_containers, tenant.id, account.id, dataset_2.id
)
tag_ids = [tag_1.id, tag_2.id]
# Act
@ -274,16 +295,17 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 2
assert total == 2
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers):
def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session):
"""Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1
per_page = 20
tag_ids = []
for i in range(3):
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
name=f"dataset-{i}",
@ -300,19 +322,21 @@ class TestDatasetServiceGetDatasets:
# ==================== Permission-Based Filtering Tests ====================
def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers):
def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session):
"""Test that without user, only ALL_TEAM datasets are shown."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
page = 1
per_page = 20
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
permission=DatasetPermissionEnum.ALL_TEAM,
)
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=account.id,
permission=DatasetPermissionEnum.ONLY_ME,
@ -325,15 +349,18 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1
assert total == 1
def test_get_datasets_owner_with_include_all(self, db_session_with_containers):
def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session):
"""Test that OWNER with include_all=True sees all datasets."""
# Arrange
owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
for i, permission in enumerate(
[DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM]
):
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
name=f"dataset-{i}",
@ -353,12 +380,15 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 3
assert total == 3
def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers):
def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session):
"""Test that normal user sees ONLY_ME datasets they created."""
# Arrange
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.NORMAL
)
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
permission=DatasetPermissionEnum.ONLY_ME,
@ -371,13 +401,18 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1
assert total == 1
def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers):
def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session):
"""Test that normal user sees ALL_TEAM datasets."""
# Arrange
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.NORMAL
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
)
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.ALL_TEAM,
@ -390,18 +425,25 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1
assert total == 1
def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers):
def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session):
"""Test that normal user sees PARTIAL_TEAM datasets they have permission for."""
# Arrange
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.NORMAL
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.PARTIAL_TEAM,
)
DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id)
DatasetRetrievalTestDataFactory.create_dataset_permission(
db_session_with_containers, dataset.id, tenant.id, user.id
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user)
@ -410,20 +452,25 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1
assert total == 1
def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers):
def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session):
"""Test that DATASET_OPERATOR only sees datasets they have explicit permission for."""
# Arrange
operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.DATASET_OPERATOR
db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.ONLY_ME,
)
DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id)
DatasetRetrievalTestDataFactory.create_dataset_permission(
db_session_with_containers, dataset.id, tenant.id, operator.id
)
# Act
datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator)
@ -432,14 +479,17 @@ class TestDatasetServiceGetDatasets:
assert len(datasets) == 1
assert total == 1
def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers):
def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session):
"""Test that DATASET_OPERATOR without permissions returns empty result."""
# Arrange
operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(
role=TenantAccountRole.DATASET_OPERATOR
db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(
db_session_with_containers, tenant, role=TenantAccountRole.OWNER
)
owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER)
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
permission=DatasetPermissionEnum.ALL_TEAM,
@ -456,11 +506,13 @@ class TestDatasetServiceGetDatasets:
class TestDatasetServiceGetDataset:
"""Comprehensive integration tests for DatasetService.get_dataset method."""
def test_get_dataset_success(self, db_session_with_containers):
def test_get_dataset_success(self, db_session_with_containers: Session):
"""Test successful retrieval of a single dataset."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
# Act
result = DatasetService.get_dataset(dataset.id)
@ -469,7 +521,7 @@ class TestDatasetServiceGetDataset:
assert result is not None
assert result.id == dataset.id
def test_get_dataset_not_found(self, db_session_with_containers):
def test_get_dataset_not_found(self, db_session_with_containers: Session):
"""Test retrieval when dataset doesn't exist."""
# Arrange
dataset_id = str(uuid4())
@ -484,12 +536,15 @@ class TestDatasetServiceGetDataset:
class TestDatasetServiceGetDatasetsByIds:
"""Comprehensive integration tests for DatasetService.get_datasets_by_ids method."""
def test_get_datasets_by_ids_success(self, db_session_with_containers):
def test_get_datasets_by_ids_success(self, db_session_with_containers: Session):
"""Test successful bulk retrieval of datasets by IDs."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
datasets = [
DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3)
DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
for _ in range(3)
]
dataset_ids = [dataset.id for dataset in datasets]
@ -501,7 +556,7 @@ class TestDatasetServiceGetDatasetsByIds:
assert total == 3
assert all(dataset.id in dataset_ids for dataset in result_datasets)
def test_get_datasets_by_ids_empty_list(self, db_session_with_containers):
def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session):
"""Test get_datasets_by_ids with empty list returns empty result."""
# Arrange
tenant_id = str(uuid4())
@ -514,7 +569,7 @@ class TestDatasetServiceGetDatasetsByIds:
assert datasets == []
assert total == 0
def test_get_datasets_by_ids_none_list(self, db_session_with_containers):
def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session):
"""Test get_datasets_by_ids with None returns empty result."""
# Arrange
tenant_id = str(uuid4())
@ -530,17 +585,20 @@ class TestDatasetServiceGetDatasetsByIds:
class TestDatasetServiceGetProcessRules:
"""Comprehensive integration tests for DatasetService.get_process_rules method."""
def test_get_process_rules_with_existing_rule(self, db_session_with_containers):
def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session):
"""Test retrieval of process rules when rule exists."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
rules_data = {
"pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
"segmentation": {"delimiter": "\n", "max_tokens": 500},
}
DatasetRetrievalTestDataFactory.create_process_rule(
db_session_with_containers,
dataset_id=dataset.id,
created_by=account.id,
mode="custom",
@ -554,11 +612,13 @@ class TestDatasetServiceGetProcessRules:
assert result["mode"] == "custom"
assert result["rules"] == rules_data
def test_get_process_rules_without_existing_rule(self, db_session_with_containers):
def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session):
"""Test retrieval of process rules when no rule exists (returns defaults)."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
# Act
result = DatasetService.get_process_rules(dataset.id)
@ -572,16 +632,19 @@ class TestDatasetServiceGetProcessRules:
class TestDatasetServiceGetDatasetQueries:
"""Comprehensive integration tests for DatasetService.get_dataset_queries method."""
def test_get_dataset_queries_success(self, db_session_with_containers):
def test_get_dataset_queries_success(self, db_session_with_containers: Session):
"""Test successful retrieval of dataset queries."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
page = 1
per_page = 20
for i in range(3):
DatasetRetrievalTestDataFactory.create_dataset_query(
db_session_with_containers,
dataset_id=dataset.id,
created_by=account.id,
content=f"query-{i}",
@ -595,11 +658,13 @@ class TestDatasetServiceGetDatasetQueries:
assert total == 3
assert all(query.dataset_id == dataset.id for query in queries)
def test_get_dataset_queries_empty_result(self, db_session_with_containers):
def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session):
"""Test retrieval when no queries exist."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
page = 1
per_page = 20
@ -614,14 +679,16 @@ class TestDatasetServiceGetDatasetQueries:
class TestDatasetServiceGetRelatedApps:
"""Comprehensive integration tests for DatasetService.get_related_apps method."""
def test_get_related_apps_success(self, db_session_with_containers):
def test_get_related_apps_success(self, db_session_with_containers: Session):
"""Test successful retrieval of related apps."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
for _ in range(2):
DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id)
DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id)
# Act
result = DatasetService.get_related_apps(dataset.id)
@ -630,11 +697,13 @@ class TestDatasetServiceGetRelatedApps:
assert len(result) == 2
assert all(join.dataset_id == dataset.id for join in result)
def test_get_related_apps_empty_result(self, db_session_with_containers):
def test_get_related_apps_empty_result(self, db_session_with_containers: Session):
"""Test retrieval when no related apps exist."""
# Arrange
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant()
dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id)
account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetRetrievalTestDataFactory.create_dataset(
db_session_with_containers, tenant_id=tenant.id, created_by=account.id
)
# Act
result = DatasetService.get_related_apps(dataset.id)

View File

@ -2,9 +2,9 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
@ -15,7 +15,9 @@ class DatasetUpdateTestDataFactory:
"""Factory class for creating real test data for dataset update integration tests."""
@staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
def create_account_with_tenant(
db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER
) -> tuple[Account, Tenant]:
"""Create a real account and tenant with the given role."""
account = Account(
email=f"{uuid4()}@example.com",
@ -23,12 +25,12 @@ class DatasetUpdateTestDataFactory:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=f"tenant-{account.id}", status="normal")
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
@ -36,14 +38,15 @@ class DatasetUpdateTestDataFactory:
role=role,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_dataset(
db_session_with_containers: Session,
tenant_id: str,
created_by: str,
provider: str = "vendor",
@ -71,12 +74,13 @@ class DatasetUpdateTestDataFactory:
embedding_model=embedding_model,
collection_binding_id=collection_binding_id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@staticmethod
def create_external_binding(
db_session_with_containers: Session,
tenant_id: str,
dataset_id: str,
created_by: str,
@ -93,8 +97,8 @@ class DatasetUpdateTestDataFactory:
external_knowledge_id=external_knowledge_id,
external_knowledge_api_id=external_knowledge_api_id,
)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(binding)
db_session_with_containers.commit()
return binding
@ -112,10 +116,11 @@ class TestDatasetServiceUpdateDataset:
# ==================== External Dataset Tests ====================
def test_update_external_dataset_success(self, db_session_with_containers):
def test_update_external_dataset_success(self, db_session_with_containers: Session):
"""Test successful update of external dataset."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="external",
@ -124,12 +129,13 @@ class TestDatasetServiceUpdateDataset:
retrieval_model="old_model",
)
binding = DatasetUpdateTestDataFactory.create_external_binding(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=user.id,
)
binding_id = binding.id
db.session.expunge(binding)
db_session_with_containers.expunge(binding)
update_data = {
"name": "new_name",
@ -142,8 +148,8 @@ class TestDatasetServiceUpdateDataset:
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
db_session_with_containers.refresh(dataset)
updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
assert dataset.name == "new_name"
assert dataset.description == "new_description"
@ -153,15 +159,17 @@ class TestDatasetServiceUpdateDataset:
assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
assert result.id == dataset.id
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session):
"""Test error when external knowledge id is missing."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="external",
)
DatasetUpdateTestDataFactory.create_external_binding(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=user.id,
@ -173,17 +181,19 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge id is required" in str(context.value)
db.session.rollback()
db_session_with_containers.rollback()
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers):
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session):
"""Test error when external knowledge api id is missing."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="external",
)
DatasetUpdateTestDataFactory.create_external_binding(
db_session_with_containers,
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=user.id,
@ -195,12 +205,13 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge api id is required" in str(context.value)
db.session.rollback()
db_session_with_containers.rollback()
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers):
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session):
"""Test error when external knowledge binding is not found."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="external",
@ -216,15 +227,16 @@ class TestDatasetServiceUpdateDataset:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge binding not found" in str(context.value)
db.session.rollback()
db_session_with_containers.rollback()
# ==================== Internal Dataset Basic Tests ====================
def test_update_internal_dataset_basic_success(self, db_session_with_containers):
def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session):
"""Test successful update of internal dataset with basic fields."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
@ -244,7 +256,7 @@ class TestDatasetServiceUpdateDataset:
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.description == "new_description"
@ -254,11 +266,12 @@ class TestDatasetServiceUpdateDataset:
assert dataset.embedding_model == "text-embedding-ada-002"
assert result.id == dataset.id
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers):
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session):
"""Test that None values are filtered out except for description field."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
@ -278,7 +291,7 @@ class TestDatasetServiceUpdateDataset:
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.description is None
@ -289,11 +302,12 @@ class TestDatasetServiceUpdateDataset:
# ==================== Indexing Technique Switch Tests ====================
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers):
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session):
"""Test updating internal dataset indexing technique to economy."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
@ -312,7 +326,7 @@ class TestDatasetServiceUpdateDataset:
result = DatasetService.update_dataset(dataset.id, update_data, user)
mock_task.delay.assert_called_once_with(dataset.id, "remove")
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "economy"
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
@ -320,10 +334,11 @@ class TestDatasetServiceUpdateDataset:
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers):
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session):
"""Test updating internal dataset indexing technique to high_quality."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
@ -366,7 +381,7 @@ class TestDatasetServiceUpdateDataset:
mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002")
mock_task.delay.assert_called_once_with(dataset.id, "add")
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.indexing_technique == "high_quality"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
@ -380,9 +395,10 @@ class TestDatasetServiceUpdateDataset:
self, db_session_with_containers
):
"""Test preserving embedding settings when indexing technique remains unchanged."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
@ -399,7 +415,7 @@ class TestDatasetServiceUpdateDataset:
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality"
@ -409,11 +425,12 @@ class TestDatasetServiceUpdateDataset:
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers):
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session):
"""Test updating internal dataset with new embedding model."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
@ -465,7 +482,7 @@ class TestDatasetServiceUpdateDataset:
regenerate_vectors_only=True,
)
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.embedding_model == "text-embedding-3-small"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
@ -474,9 +491,9 @@ class TestDatasetServiceUpdateDataset:
# ==================== Error Handling Tests ====================
def test_update_dataset_not_found_error(self, db_session_with_containers):
def test_update_dataset_not_found_error(self, db_session_with_containers: Session):
"""Test error when dataset is not found."""
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
update_data = {"name": "new_name"}
with pytest.raises(ValueError) as context:
@ -484,11 +501,16 @@ class TestDatasetServiceUpdateDataset:
assert "Dataset not found" in str(context.value)
def test_update_dataset_permission_error(self, db_session_with_containers):
def test_update_dataset_permission_error(self, db_session_with_containers: Session):
"""Test error when user doesn't have permission."""
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.OWNER
)
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(
db_session_with_containers, role=TenantAccountRole.NORMAL
)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=owner.id,
provider="vendor",
@ -500,10 +522,11 @@ class TestDatasetServiceUpdateDataset:
with pytest.raises(NoPermissionError):
DatasetService.update_dataset(dataset.id, update_data, outsider)
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers):
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session):
"""Test error when embedding model is not available."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers)
dataset = DatasetUpdateTestDataFactory.create_dataset(
db_session_with_containers,
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",

View File

@ -5,6 +5,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy import Engine
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -19,7 +20,7 @@ class TestFileService:
"""Integration tests for FileService using testcontainers."""
@pytest.fixture
def engine(self, db_session_with_containers):
def engine(self, db_session_with_containers: Session):
bind = db_session_with_containers.get_bind()
assert isinstance(bind, Engine)
return bind
@ -46,7 +47,7 @@ class TestFileService:
"extract_processor": mock_extract_processor,
}
def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account for testing.
@ -67,18 +68,16 @@ class TestFileService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
@ -89,15 +88,15 @@ class TestFileService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account
def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test end user for testing.
@ -118,14 +117,14 @@ class TestFileService:
session_id=fake.uuid4(),
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
return end_user
def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account):
def _create_test_upload_file(
self, db_session_with_containers: Session, mock_external_service_dependencies, account
):
"""
Helper method to create a test upload file for testing.
@ -155,15 +154,13 @@ class TestFileService:
source_url="",
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
# Test upload_file method
def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies):
"""
Test successful file upload with valid parameters.
"""
@ -196,7 +193,9 @@ class TestFileService:
assert upload_file.id is not None
def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_with_end_user(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with end user instead of account.
"""
@ -219,7 +218,7 @@ class TestFileService:
assert upload_file.created_by_role == CreatorUserRole.END_USER
def test_upload_file_with_datasets_source(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with datasets source parameter.
@ -244,7 +243,7 @@ class TestFileService:
assert upload_file.source_url == "https://example.com/source"
def test_upload_file_invalid_filename_characters(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with invalid filename characters.
@ -265,7 +264,7 @@ class TestFileService:
)
def test_upload_file_filename_too_long(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with filename that exceeds length limit.
@ -295,7 +294,7 @@ class TestFileService:
assert len(base_name) <= 200
def test_upload_file_datasets_unsupported_type(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload for datasets with unsupported file type.
@ -316,7 +315,9 @@ class TestFileService:
source="datasets",
)
def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_too_large(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with file size exceeding limit.
"""
@ -338,7 +339,7 @@ class TestFileService:
# Test is_file_size_within_limit method
def test_is_file_size_within_limit_image_success(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size check for image files within limit.
@ -351,7 +352,7 @@ class TestFileService:
assert result is True
def test_is_file_size_within_limit_video_success(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size check for video files within limit.
@ -364,7 +365,7 @@ class TestFileService:
assert result is True
def test_is_file_size_within_limit_audio_success(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size check for audio files within limit.
@ -377,7 +378,7 @@ class TestFileService:
assert result is True
def test_is_file_size_within_limit_document_success(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size check for document files within limit.
@ -390,7 +391,7 @@ class TestFileService:
assert result is True
def test_is_file_size_within_limit_image_exceeded(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size check for image files exceeding limit.
@ -403,7 +404,7 @@ class TestFileService:
assert result is False
def test_is_file_size_within_limit_unknown_extension(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size check for unknown file extension.
@ -416,7 +417,7 @@ class TestFileService:
assert result is True
# Test upload_text method
def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies):
"""
Test successful text upload.
"""
@ -447,7 +448,9 @@ class TestFileService:
# Verify storage was called
mock_external_service_dependencies["storage"].save.assert_called_once()
def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_text_name_too_long(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test text upload with name that exceeds length limit.
"""
@ -472,7 +475,9 @@ class TestFileService:
assert upload_file.name == "a" * 200
# Test get_file_preview method
def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_get_file_preview_success(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test successful file preview generation.
"""
@ -484,9 +489,8 @@ class TestFileService:
# Update file to have document extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
result = FileService(engine).get_file_preview(file_id=upload_file.id)
@ -494,7 +498,7 @@ class TestFileService:
mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once()
def test_get_file_preview_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file preview with non-existent file.
@ -506,7 +510,7 @@ class TestFileService:
FileService(engine).get_file_preview(file_id=non_existent_id)
def test_get_file_preview_unsupported_file_type(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file preview with unsupported file type.
@ -519,15 +523,14 @@ class TestFileService:
# Update file to have non-document extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
with pytest.raises(UnsupportedFileTypeError):
FileService(engine).get_file_preview(file_id=upload_file.id)
def test_get_file_preview_text_truncation(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file preview with text that exceeds preview limit.
@ -540,9 +543,8 @@ class TestFileService:
# Update file to have document extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock long text content
long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT
@ -554,7 +556,9 @@ class TestFileService:
assert result == "x" * 3000
# Test get_image_preview method
def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_get_image_preview_success(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test successful image preview generation.
"""
@ -566,9 +570,8 @@ class TestFileService:
# Update file to have image extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
timestamp = "1234567890"
nonce = "test_nonce"
@ -586,7 +589,7 @@ class TestFileService:
mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once()
def test_get_image_preview_invalid_signature(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test image preview with invalid signature.
@ -613,7 +616,7 @@ class TestFileService:
)
def test_get_image_preview_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test image preview with non-existent file.
@ -634,7 +637,7 @@ class TestFileService:
)
def test_get_image_preview_unsupported_file_type(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test image preview with non-image file type.
@ -647,9 +650,8 @@ class TestFileService:
# Update file to have non-image extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
timestamp = "1234567890"
nonce = "test_nonce"
@ -665,7 +667,7 @@ class TestFileService:
# Test get_file_generator_by_file_id method
def test_get_file_generator_by_file_id_success(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test successful file generator retrieval.
@ -692,7 +694,7 @@ class TestFileService:
mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once()
def test_get_file_generator_by_file_id_invalid_signature(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file generator retrieval with invalid signature.
@ -719,7 +721,7 @@ class TestFileService:
)
def test_get_file_generator_by_file_id_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file generator retrieval with non-existent file.
@ -741,7 +743,7 @@ class TestFileService:
# Test get_public_image_preview method
def test_get_public_image_preview_success(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test successful public image preview generation.
@ -754,9 +756,8 @@ class TestFileService:
# Update file to have image extension
upload_file.extension = "jpg"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id)
@ -765,7 +766,7 @@ class TestFileService:
mock_external_service_dependencies["storage"].load.assert_called_once()
def test_get_public_image_preview_file_not_found(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test public image preview with non-existent file.
@ -777,7 +778,7 @@ class TestFileService:
FileService(engine).get_public_image_preview(file_id=non_existent_id)
def test_get_public_image_preview_unsupported_file_type(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test public image preview with non-image file type.
@ -790,15 +791,16 @@ class TestFileService:
# Update file to have non-image extension
upload_file.extension = "pdf"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
with pytest.raises(UnsupportedFileTypeError):
FileService(engine).get_public_image_preview(file_id=upload_file.id)
# Test edge cases and boundary conditions
def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_empty_content(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with empty content.
"""
@ -820,7 +822,7 @@ class TestFileService:
assert upload_file.size == 0
def test_upload_file_special_characters_in_name(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with special characters in filename (but valid ones).
@ -843,7 +845,7 @@ class TestFileService:
assert upload_file.name == filename
def test_upload_file_different_case_extensions(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with different case extensions.
@ -865,7 +867,9 @@ class TestFileService:
assert upload_file is not None
assert upload_file.extension == "pdf" # Should be converted to lowercase
def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_text_empty_text(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test text upload with empty text.
"""
@ -888,7 +892,9 @@ class TestFileService:
assert upload_file is not None
assert upload_file.size == 0
def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_file_size_limits_edge_cases(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file size limits with edge case values.
"""
@ -908,7 +914,9 @@ class TestFileService:
result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size)
assert result is False
def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_with_source_url(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with source URL that gets overridden by signed URL.
"""
@ -946,7 +954,7 @@ class TestFileService:
# Test file extension blacklist
def test_upload_file_blocked_extension(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with blocked extension.
@ -969,7 +977,7 @@ class TestFileService:
)
def test_upload_file_blocked_extension_case_insensitive(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with blocked extension (case insensitive).
@ -992,7 +1000,9 @@ class TestFileService:
user=account,
)
def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_not_in_blacklist(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with extension not in blacklist.
"""
@ -1016,7 +1026,9 @@ class TestFileService:
assert upload_file.name == filename
assert upload_file.extension == "pdf"
def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies):
def test_upload_file_empty_blacklist(
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with empty blacklist (default behavior).
"""
@ -1041,7 +1053,7 @@ class TestFileService:
assert upload_file.extension == "sh"
def test_upload_file_multiple_blocked_extensions(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with multiple blocked extensions.
@ -1066,7 +1078,7 @@ class TestFileService:
)
def test_upload_file_no_extension_with_blacklist(
self, db_session_with_containers, engine, mock_external_service_dependencies
self, db_session_with_containers: Session, engine, mock_external_service_dependencies
):
"""
Test file upload with no extension when blacklist is configured.

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.model import MessageFeedback
from services.app_service import AppService
@ -69,7 +70,7 @@ class TestMessageService:
# "current_user": mock_current_user,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -127,11 +128,10 @@ class TestMessageService:
# mock_external_service_dependencies["current_user"].id = account_id
# mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id
def _create_test_conversation(self, app, account, fake):
def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake):
"""
Helper method to create a test conversation with all required fields.
"""
from extensions.ext_database import db
from models.model import Conversation
conversation = Conversation(
@ -153,17 +153,16 @@ class TestMessageService:
from_account_id=account.id,
)
db.session.add(conversation)
db.session.flush()
db_session_with_containers.add(conversation)
db_session_with_containers.flush()
return conversation
def _create_test_message(self, app, conversation, account, fake):
def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake):
"""
Helper method to create a test message with all required fields.
"""
import json
from extensions.ext_database import db
from models.model import Message
message = Message(
@ -192,11 +191,13 @@ class TestMessageService:
from_account_id=account.id,
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
return message
def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_first_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination by first ID.
"""
@ -204,10 +205,10 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
messages = []
for i in range(5):
message = self._create_test_message(app, conversation, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
messages.append(message)
# Test pagination by first ID
@ -228,7 +229,9 @@ class TestMessageService:
# Verify messages are in ascending order
assert result.data[0].created_at <= result.data[1].created_at
def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_first_id_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by first ID when no user is provided.
"""
@ -246,7 +249,7 @@ class TestMessageService:
assert result.has_more is False
def test_pagination_by_first_id_no_conversation_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by first ID when no conversation ID is provided.
@ -265,7 +268,7 @@ class TestMessageService:
assert result.has_more is False
def test_pagination_by_first_id_invalid_first_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by first ID with invalid first_id.
@ -274,8 +277,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test pagination with invalid first_id
with pytest.raises(FirstMessageNotExistsError):
@ -287,7 +290,9 @@ class TestMessageService:
limit=10,
)
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_last_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination by last ID.
"""
@ -295,10 +300,10 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
messages = []
for i in range(5):
message = self._create_test_message(app, conversation, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
messages.append(message)
# Test pagination by last ID
@ -319,7 +324,7 @@ class TestMessageService:
assert result.data[0].created_at >= result.data[1].created_at
def test_pagination_by_last_id_with_include_ids(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by last ID with include_ids filter.
@ -328,10 +333,10 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and multiple messages
conversation = self._create_test_conversation(app, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
messages = []
for i in range(5):
message = self._create_test_message(app, conversation, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
messages.append(message)
# Test pagination with include_ids
@ -347,7 +352,9 @@ class TestMessageService:
for message in result.data:
assert message.id in include_ids
def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_last_id_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by last ID when no user is provided.
"""
@ -363,7 +370,7 @@ class TestMessageService:
assert result.has_more is False
def test_pagination_by_last_id_invalid_last_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by last ID with invalid last_id.
@ -372,8 +379,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test pagination with invalid last_id
with pytest.raises(LastMessageNotExistsError):
@ -385,7 +392,7 @@ class TestMessageService:
conversation_id=conversation.id,
)
def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful creation of feedback.
"""
@ -393,8 +400,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create feedback
rating = "like"
@ -413,7 +420,7 @@ class TestMessageService:
assert feedback.from_account_id == account.id
assert feedback.from_end_user_id is None
def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test creating feedback when no user is provided.
"""
@ -421,8 +428,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test creating feedback with no user
with pytest.raises(ValueError, match="user cannot be None"):
@ -430,7 +437,9 @@ class TestMessageService:
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
)
def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_feedback_update_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test updating existing feedback.
"""
@ -438,8 +447,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback
initial_rating = "like"
@ -462,7 +471,9 @@ class TestMessageService:
assert updated_feedback.rating != initial_rating
assert updated_feedback.content != initial_content
def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_feedback_delete_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test deleting existing feedback by setting rating to None.
"""
@ -470,8 +481,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback
feedback = MessageService.create_feedback(
@ -482,13 +493,14 @@ class TestMessageService:
MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None)
# Verify feedback was deleted
from extensions.ext_database import db
deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
deleted_feedback = (
db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first()
)
assert deleted_feedback is None
def test_create_feedback_no_rating_when_not_exists(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creating feedback with no rating when feedback doesn't exist.
@ -497,8 +509,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test creating feedback with no rating when no feedback exists
with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"):
@ -506,7 +518,9 @@ class TestMessageService:
app_model=app, message_id=message.id, user=account, rating=None, content=None
)
def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_all_messages_feedbacks_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of all message feedbacks.
"""
@ -516,8 +530,8 @@ class TestMessageService:
# Create multiple conversations and messages with feedbacks
feedbacks = []
for i in range(3):
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
feedback = MessageService.create_feedback(
app_model=app,
@ -539,7 +553,7 @@ class TestMessageService:
assert result[i]["created_at"] >= result[i + 1]["created_at"]
def test_get_all_messages_feedbacks_pagination(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination of message feedbacks.
@ -549,8 +563,8 @@ class TestMessageService:
# Create multiple conversations and messages with feedbacks
for i in range(5):
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
@ -569,7 +583,7 @@ class TestMessageService:
page_2_ids = {feedback["id"] for feedback in result_page_2}
assert len(page_1_ids.intersection(page_2_ids)) == 0
def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of message.
"""
@ -577,8 +591,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Get message
retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id)
@ -590,7 +604,7 @@ class TestMessageService:
assert retrieved_message.from_source == "console"
assert retrieved_message.from_account_id == account.id
def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test getting message that doesn't exist.
"""
@ -601,7 +615,7 @@ class TestMessageService:
with pytest.raises(MessageNotExistsError):
MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4())
def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test getting message with wrong user (different account).
"""
@ -609,8 +623,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create another account
from services.account_service import AccountService, TenantService
@ -628,7 +642,7 @@ class TestMessageService:
MessageService.get_message(app_model=app, user=other_account, message_id=message.id)
def test_get_suggested_questions_after_answer_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful generation of suggested questions after answer.
@ -637,8 +651,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock the LLMGenerator to return specific questions
mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"]
@ -665,7 +679,7 @@ class TestMessageService:
mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once()
def test_get_suggested_questions_after_answer_no_user(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting suggested questions when no user is provided.
@ -674,8 +688,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Test getting suggested questions with no user
from core.app.entities.app_invoke_entities import InvokeFrom
@ -686,7 +700,7 @@ class TestMessageService:
)
def test_get_suggested_questions_after_answer_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting suggested questions when feature is disabled.
@ -695,8 +709,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock the feature to be disabled
mock_external_service_dependencies[
@ -712,7 +726,7 @@ class TestMessageService:
)
def test_get_suggested_questions_after_answer_no_workflow(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting suggested questions when no workflow exists.
@ -721,8 +735,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock no workflow
mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None
@ -738,7 +752,7 @@ class TestMessageService:
assert result == []
def test_get_suggested_questions_after_answer_debugger_mode(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting suggested questions in debugger mode.
@ -747,8 +761,8 @@ class TestMessageService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
# Create a conversation and message
conversation = self._create_test_conversation(app, account, fake)
message = self._create_test_message(app, conversation, account, fake)
conversation = self._create_test_conversation(db_session_with_containers, app, account, fake)
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Mock questions
mock_questions = ["Debug question 1", "Debug question 2"]

View File

@ -6,9 +6,9 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
@ -40,25 +40,25 @@ class TestMessagesCleanServiceIntegration:
PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:"
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before and after each test to ensure isolation."""
yield
# Clear all test data in correct order (respecting foreign key constraints)
db.session.query(DatasetRetrieverResource).delete()
db.session.query(AppAnnotationHitHistory).delete()
db.session.query(SavedMessage).delete()
db.session.query(MessageFile).delete()
db.session.query(MessageAgentThought).delete()
db.session.query(MessageChain).delete()
db.session.query(MessageAnnotation).delete()
db.session.query(MessageFeedback).delete()
db.session.query(Message).delete()
db.session.query(Conversation).delete()
db.session.query(App).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
@pytest.fixture(autouse=True)
def cleanup_redis(self):
@ -100,7 +100,7 @@ class TestMessagesCleanServiceIntegration:
with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False):
yield
def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX):
def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX):
"""Helper to create account and tenant."""
fake = Faker()
@ -110,28 +110,28 @@ class TestMessagesCleanServiceIntegration:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.flush()
db_session_with_containers.add(account)
db_session_with_containers.flush()
tenant = Tenant(
name=fake.company(),
plan=str(plan),
status="normal",
)
db.session.add(tenant)
db.session.flush()
db_session_with_containers.add(tenant)
db_session_with_containers.flush()
tenant_account_join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
db_session_with_containers.add(tenant_account_join)
db_session_with_containers.commit()
return account, tenant
def _create_app(self, tenant, account):
def _create_app(self, db_session_with_containers: Session, tenant, account):
"""Helper to create an app."""
fake = Faker()
@ -149,12 +149,12 @@ class TestMessagesCleanServiceIntegration:
created_by=account.id,
updated_by=account.id,
)
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def _create_conversation(self, app):
def _create_conversation(self, db_session_with_containers: Session, app):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,
@ -168,12 +168,14 @@ class TestMessagesCleanServiceIntegration:
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
return conversation
def _create_message(self, app, conversation, created_at=None, with_relations=True):
def _create_message(
self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True
):
"""Helper to create a message with optional related records."""
if created_at is None:
created_at = datetime.datetime.now()
@ -197,16 +199,16 @@ class TestMessagesCleanServiceIntegration:
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)
db.session.add(message)
db.session.flush()
db_session_with_containers.add(message)
db_session_with_containers.flush()
if with_relations:
self._create_message_relations(message)
self._create_message_relations(db_session_with_containers, message)
db.session.commit()
db_session_with_containers.commit()
return message
def _create_message_relations(self, message):
def _create_message_relations(self, db_session_with_containers: Session, message):
"""Helper to create all message-related records."""
# MessageFeedback
feedback = MessageFeedback(
@ -217,7 +219,7 @@ class TestMessagesCleanServiceIntegration:
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
db.session.add(feedback)
db_session_with_containers.add(feedback)
# MessageAnnotation
annotation = MessageAnnotation(
@ -228,7 +230,7 @@ class TestMessagesCleanServiceIntegration:
content="Test annotation",
account_id=message.from_account_id,
)
db.session.add(annotation)
db_session_with_containers.add(annotation)
# MessageChain
chain = MessageChain(
@ -237,8 +239,8 @@ class TestMessagesCleanServiceIntegration:
input=json.dumps({"test": "input"}),
output=json.dumps({"test": "output"}),
)
db.session.add(chain)
db.session.flush()
db_session_with_containers.add(chain)
db_session_with_containers.flush()
# MessageFile
file = MessageFile(
@ -250,7 +252,7 @@ class TestMessagesCleanServiceIntegration:
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)
db.session.add(file)
db_session_with_containers.add(file)
# SavedMessage
saved = SavedMessage(
@ -259,9 +261,9 @@ class TestMessagesCleanServiceIntegration:
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)
db.session.add(saved)
db_session_with_containers.add(saved)
db.session.flush()
db_session_with_containers.flush()
# AppAnnotationHitHistory
hit = AppAnnotationHitHistory(
@ -275,7 +277,7 @@ class TestMessagesCleanServiceIntegration:
annotation_question="Test annotation question",
annotation_content="Test annotation content",
)
db.session.add(hit)
db_session_with_containers.add(hit)
# DatasetRetrieverResource
resource = DatasetRetrieverResource(
@ -296,25 +298,29 @@ class TestMessagesCleanServiceIntegration:
retriever_from="dataset",
created_by=message.from_account_id,
)
db.session.add(resource)
db_session_with_containers.add(resource)
def test_billing_disabled_deletes_all_messages_in_time_range(
self, db_session_with_containers, mock_billing_disabled
self, db_session_with_containers: Session, mock_billing_disabled
):
"""Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan."""
# Arrange - Create tenant with messages (plan doesn't matter for billing disabled)
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create messages: in-range (should be deleted) and out-of-range (should be kept)
in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0)
out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0)
in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True)
in_range_msg = self._create_message(
db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True
)
in_range_msg_id = in_range_msg.id
out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True)
out_of_range_msg = self._create_message(
db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True
)
out_of_range_msg_id = out_of_range_msg.id
# Act - create_message_clean_policy should return BillingDisabledPolicy
@ -336,17 +342,34 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 1
# In-range message deleted
assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0
assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0
# Out-of-range message kept
assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1
# Related records of in-range message deleted
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0
assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0
assert (
db_session_with_containers.query(MessageFeedback)
.where(MessageFeedback.message_id == in_range_msg_id)
.count()
== 0
)
assert (
db_session_with_containers.query(MessageAnnotation)
.where(MessageAnnotation.message_id == in_range_msg_id)
.count()
== 0
)
# Related records of out-of-range message kept
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1
assert (
db_session_with_containers.query(MessageFeedback)
.where(MessageFeedback.message_id == out_of_range_msg_id)
.count()
== 1
)
def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_no_messages_returns_empty_stats(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test cleaning when there are no messages to delete (B1)."""
# Arrange
end_before = datetime.datetime.now() - datetime.timedelta(days=30)
@ -371,36 +394,42 @@ class TestMessagesCleanServiceIntegration:
assert stats["filtered_messages"] == 0
assert stats["total_deleted"] == 0
def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_mixed_sandbox_and_paid_tenants(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test cleaning with mixed sandbox and paid tenants (B2)."""
# Arrange - Create sandbox tenants with expired messages
sandbox_tenants = []
sandbox_message_ids = []
for i in range(2):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
sandbox_tenants.append(tenant)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create 3 expired messages per sandbox tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j)
)
sandbox_message_ids.append(msg.id)
# Create paid tenants with expired messages (should NOT be deleted)
paid_tenants = []
paid_message_ids = []
for i in range(2):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL)
paid_tenants.append(tenant)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create 2 expired messages per paid tenant
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
for j in range(2):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j))
msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j)
)
paid_message_ids.append(msg.id)
# Mock billing service - return plan and expiration_date
@ -442,29 +471,39 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 6
# Only sandbox messages should be deleted
assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0
# Paid messages should remain
assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4
# Related records of sandbox messages should be deleted
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0
assert (
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count()
db_session_with_containers.query(MessageFeedback)
.where(MessageFeedback.message_id.in_(sandbox_message_ids))
.count()
== 0
)
assert (
db_session_with_containers.query(MessageAnnotation)
.where(MessageAnnotation.message_id.in_(sandbox_message_ids))
.count()
== 0
)
def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_cursor_pagination_multiple_batches(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test cursor pagination works correctly across multiple batches (B3)."""
# Arrange - Create sandbox tenant with messages that will span multiple batches
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create 10 expired messages with different timestamps
base_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = []
for i in range(10):
msg = self._create_message(
db_session_with_containers,
app,
conv,
created_at=base_date + datetime.timedelta(hours=i),
@ -498,20 +537,22 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 10
# All messages should be deleted
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0
assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0
def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_dry_run_does_not_delete(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
"""Test dry_run mode does not delete messages (B4)."""
# Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create expired messages
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
message_ids = []
for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i)
)
message_ids.append(msg.id)
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
@ -540,21 +581,26 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0 # But NOT deleted
# All messages should still exist
assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3
assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3
# Related records should also still exist
assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3
assert (
db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count()
== 3
)
def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_partial_plan_data_safe_default(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test when billing returns partial data, unknown tenants are preserved (B5)."""
# Arrange - Create 3 tenants
tenants_data = []
for i in range(3):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date)
msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date)
tenants_data.append(
{
@ -600,28 +646,30 @@ class TestMessagesCleanServiceIntegration:
# Check which messages were deleted
assert (
db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0
) # Sandbox tenant's message deleted
assert (
db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
) # Professional tenant's message preserved
assert (
db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1
) # Unknown tenant's message preserved (safe default)
def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_empty_plan_data_skips_deletion(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test when billing returns empty data, skip deletion entirely (B6)."""
# Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date)
msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date)
msg_id = msg.id
db.session.commit()
db_session_with_containers.commit()
# Mock billing service to return empty data (simulating failure/no data scenario)
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
@ -644,17 +692,20 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0
# Message should still exist (safe default - don't delete if plan is unknown)
assert db.session.query(Message).where(Message.id == msg_id).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1
def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_time_range_boundary_behavior(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test that messages are correctly filtered by [start_from, end_before) time range (B7)."""
# Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create messages: before range, in range, after range
msg_before = self._create_message(
db_session_with_containers,
app,
conv,
created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from
@ -663,6 +714,7 @@ class TestMessagesCleanServiceIntegration:
msg_before_id = msg_before.id
msg_at_start = self._create_message(
db_session_with_containers,
app,
conv,
created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive)
@ -671,6 +723,7 @@ class TestMessagesCleanServiceIntegration:
msg_at_start_id = msg_at_start.id
msg_in_range = self._create_message(
db_session_with_containers,
app,
conv,
created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range
@ -679,6 +732,7 @@ class TestMessagesCleanServiceIntegration:
msg_in_range_id = msg_in_range.id
msg_at_end = self._create_message(
db_session_with_containers,
app,
conv,
created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive)
@ -687,6 +741,7 @@ class TestMessagesCleanServiceIntegration:
msg_at_end_id = msg_at_end.id
msg_after = self._create_message(
db_session_with_containers,
app,
conv,
created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before
@ -694,7 +749,7 @@ class TestMessagesCleanServiceIntegration:
)
msg_after_id = msg_after.id
db.session.commit()
db_session_with_containers.commit()
# Mock billing service
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
@ -722,17 +777,17 @@ class TestMessagesCleanServiceIntegration:
# Verify specific messages using stored IDs
# Before range, kept
assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1
# At start (inclusive), deleted
assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0
assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0
# In range, deleted
assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0
assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0
# At end (exclusive), kept
assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1
# After range, kept
assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1
def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
"""Test cleaning with different graceful period scenarios (B8)."""
# Arrange - Create 5 different tenants with different plan and expiration scenarios
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
@ -740,50 +795,60 @@ class TestMessagesCleanServiceIntegration:
# Scenario 1: Sandbox plan with expiration within graceful period (5 days ago)
# Should NOT be deleted
account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app1 = self._create_app(tenant1, account1)
conv1 = self._create_conversation(app1)
account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app1 = self._create_app(db_session_with_containers, tenant1, account1)
conv1 = self._create_conversation(db_session_with_containers, app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
msg1 = self._create_message(
db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False
)
msg1_id = msg1.id
expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period
# Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago)
# Should be deleted
account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app2 = self._create_app(tenant2, account2)
conv2 = self._create_conversation(app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app2 = self._create_app(db_session_with_containers, tenant2, account2)
conv2 = self._create_conversation(db_session_with_containers, app2)
msg2 = self._create_message(
db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False
)
msg2_id = msg2.id
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period
# Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription)
# Should be deleted
account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app3 = self._create_app(tenant3, account3)
conv3 = self._create_conversation(app3)
msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False)
account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app3 = self._create_app(db_session_with_containers, tenant3, account3)
conv3 = self._create_conversation(db_session_with_containers, app3)
msg3 = self._create_message(
db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False
)
msg3_id = msg3.id
# Scenario 4: Non-sandbox plan (professional) with no expiration (future date)
# Should NOT be deleted
account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL)
app4 = self._create_app(tenant4, account4)
conv4 = self._create_conversation(app4)
msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False)
account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL)
app4 = self._create_app(db_session_with_containers, tenant4, account4)
conv4 = self._create_conversation(db_session_with_containers, app4)
msg4 = self._create_message(
db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False
)
msg4_id = msg4.id
future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year
# Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago)
# Should NOT be deleted (boundary is exclusive: > graceful_period)
account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app5 = self._create_app(tenant5, account5)
conv5 = self._create_conversation(app5)
msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False)
account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app5 = self._create_app(db_session_with_containers, tenant5, account5)
conv5 = self._create_conversation(db_session_with_containers, app5)
msg5 = self._create_message(
db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False
)
msg5_id = msg5.id
expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary
db.session.commit()
db_session_with_containers.commit()
# Mock billing service with all scenarios
plan_map = {
@ -832,23 +897,31 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 2
# Verify each scenario using saved IDs
assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted
assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted
assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept
assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept
assert (
db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0
) # Beyond grace, deleted
assert (
db_session_with_containers.query(Message).where(Message.id == msg3_id).count() == 0
) # No subscription, deleted
assert (
db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1
) # Professional plan, kept
assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept
def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist):
"""Test that whitelisted tenants' messages are not deleted (B9)."""
# Arrange - Create 3 sandbox tenants with expired messages
tenants_data = []
for i in range(3):
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg = self._create_message(app, conv, created_at=expired_date, with_relations=False)
msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date, with_relations=False
)
tenants_data.append(
{
@ -897,27 +970,33 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 1
# Verify tenant0's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1
# Verify tenant1's message still exists (whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1
# Verify tenant2's message was deleted (not whitelisted)
assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0
def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist):
def test_from_days_cleans_old_messages(
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test from_days correctly cleans messages older than N days (B11)."""
# Arrange
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
# Create old messages (should be deleted - older than 30 days)
old_date = datetime.datetime.now() - datetime.timedelta(days=45)
old_msg_ids = []
for i in range(3):
msg = self._create_message(
app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False
db_session_with_containers,
app,
conv,
created_at=old_date - datetime.timedelta(hours=i),
with_relations=False,
)
old_msg_ids.append(msg.id)
@ -926,11 +1005,15 @@ class TestMessagesCleanServiceIntegration:
recent_msg_ids = []
for i in range(2):
msg = self._create_message(
app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False
db_session_with_containers,
app,
conv,
created_at=recent_date - datetime.timedelta(hours=i),
with_relations=False,
)
recent_msg_ids.append(msg.id)
db.session.commit()
db_session_with_containers.commit()
with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing:
mock_billing.return_value = {
@ -955,30 +1038,34 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 3
# Old messages deleted
assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0
assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0
# Recent messages kept
assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2
assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2
def test_whitelist_precedence_over_grace_period(
self, db_session_with_containers, mock_billing_enabled, mock_whitelist
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test that whitelist takes precedence over grace period logic."""
# Arrange - Create 2 sandbox tenants
now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp())
# Tenant1: whitelisted, expired beyond grace period
account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app1 = self._create_app(tenant1, account1)
conv1 = self._create_conversation(app1)
account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app1 = self._create_app(db_session_with_containers, tenant1, account1)
conv1 = self._create_conversation(db_session_with_containers, app1)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False)
msg1 = self._create_message(
db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False
)
expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace
# Tenant2: not whitelisted, within grace period
account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app2 = self._create_app(tenant2, account2)
conv2 = self._create_conversation(app2)
msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False)
account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app2 = self._create_app(db_session_with_containers, tenant2, account2)
conv2 = self._create_conversation(db_session_with_containers, app2)
msg2 = self._create_message(
db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False
)
expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace
# Mock billing service
@ -1019,22 +1106,26 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 0
# Verify both messages still exist
assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period
assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted
assert (
db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1
) # Within grace period
def test_empty_whitelist_deletes_eligible_messages(
self, db_session_with_containers, mock_billing_enabled, mock_whitelist
self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist
):
"""Test that empty whitelist behaves as no whitelist (all eligible messages deleted)."""
# Arrange - Create sandbox tenant with expired messages
account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX)
app = self._create_app(tenant, account)
conv = self._create_conversation(app)
account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX)
app = self._create_app(db_session_with_containers, tenant, account)
conv = self._create_conversation(db_session_with_containers, app)
expired_date = datetime.datetime.now() - datetime.timedelta(days=35)
msg_ids = []
for i in range(3):
msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i))
msg = self._create_message(
db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i)
)
msg_ids.append(msg.id)
# Mock billing service
@ -1068,4 +1159,4 @@ class TestMessagesCleanServiceIntegration:
assert stats["total_deleted"] == 3
# Verify all messages were deleted
assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0
assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0

View File

@ -2,6 +2,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.built_in_field import BuiltInField
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -32,7 +33,7 @@ class TestMetadataService:
"document_service": mock_document_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -53,18 +54,16 @@ class TestMetadataService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -73,15 +72,17 @@ class TestMetadataService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant):
def _create_test_dataset(
self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant
):
"""
Helper method to create a test dataset for testing.
@ -105,14 +106,14 @@ class TestMetadataService:
built_in_field_enabled=False,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account):
def _create_test_document(
self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account
):
"""
Helper method to create a test document for testing.
@ -141,14 +142,12 @@ class TestMetadataService:
doc_language="en",
)
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful metadata creation with valid parameters.
"""
@ -178,13 +177,14 @@ class TestMetadataService:
assert result.created_by == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.created_at is not None
def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_metadata_name_too_long(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata creation fails when name exceeds 255 characters.
"""
@ -207,7 +207,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
MetadataService.create_metadata(dataset.id, metadata_args)
def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_metadata_name_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata creation fails when name already exists in the same dataset.
"""
@ -235,7 +237,7 @@ class TestMetadataService:
MetadataService.create_metadata(dataset.id, second_metadata_args)
def test_create_metadata_name_conflicts_with_built_in_field(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata creation fails when name conflicts with built-in field names.
@ -260,7 +262,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.create_metadata(dataset.id, metadata_args)
def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_metadata_name_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful metadata name update with valid parameters.
"""
@ -291,12 +295,13 @@ class TestMetadataService:
assert result.updated_at is not None
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.name == new_name
def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_metadata_name_too_long(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata name update fails when new name exceeds 255 characters.
"""
@ -323,7 +328,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."):
MetadataService.update_metadata_name(dataset.id, metadata.id, long_name)
def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_metadata_name_already_exists(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata name update fails when new name already exists in the same dataset.
"""
@ -351,7 +358,7 @@ class TestMetadataService:
MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata")
def test_update_metadata_name_conflicts_with_built_in_field(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata name update fails when new name conflicts with built-in field names.
@ -378,7 +385,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."):
MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name)
def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_metadata_name_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata name update fails when metadata ID does not exist.
"""
@ -406,7 +415,7 @@ class TestMetadataService:
# Assert: Verify the method returns None when metadata is not found
assert result is None
def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful metadata deletion with valid parameters.
"""
@ -434,12 +443,11 @@ class TestMetadataService:
assert result.id == metadata.id
# Verify metadata was deleted from database
from extensions.ext_database import db
deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first()
deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
assert deleted_metadata is None
def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test metadata deletion fails when metadata ID does not exist.
"""
@ -467,7 +475,7 @@ class TestMetadataService:
assert result is None
def test_delete_metadata_with_document_bindings(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata deletion successfully removes document metadata bindings.
@ -500,15 +508,13 @@ class TestMetadataService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(binding)
db_session_with_containers.commit()
# Set document metadata
document.doc_metadata = {"test_metadata": "test_value"}
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act: Execute the method under test
result = MetadataService.delete_metadata(dataset.id, metadata.id)
@ -517,13 +523,13 @@ class TestMetadataService:
assert result is not None
# Verify metadata was deleted from database
deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first()
deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first()
assert deleted_metadata is None
# Note: The service attempts to update document metadata but may not succeed
# due to mock configuration. The main functionality (metadata deletion) is verified.
def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of built-in metadata fields.
"""
@ -548,7 +554,9 @@ class TestMetadataService:
assert "string" in field_types
assert "time" in field_types
def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_enable_built_in_field_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful enabling of built-in fields for a dataset.
"""
@ -579,16 +587,15 @@ class TestMetadataService:
MetadataService.enable_built_in_field(dataset)
# Assert: Verify the expected outcomes
from extensions.ext_database import db
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True
# Note: Document metadata update depends on DocumentService mock working correctly
# The main functionality (enabling built-in fields) is verified
def test_enable_built_in_field_already_enabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test enabling built-in fields when they are already enabled.
@ -607,10 +614,9 @@ class TestMetadataService:
# Enable built-in fields first
dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
@ -619,11 +625,11 @@ class TestMetadataService:
MetadataService.enable_built_in_field(dataset)
# Assert: Verify the method returns early without changes
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True
def test_enable_built_in_field_with_no_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test enabling built-in fields for a dataset with no documents.
@ -647,12 +653,13 @@ class TestMetadataService:
MetadataService.enable_built_in_field(dataset)
# Assert: Verify the expected outcomes
from extensions.ext_database import db
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is True
def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_disable_built_in_field_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful disabling of built-in fields for a dataset.
"""
@ -673,10 +680,9 @@ class TestMetadataService:
# Enable built-in fields first
dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Set document metadata with built-in fields
document.doc_metadata = {
@ -686,8 +692,8 @@ class TestMetadataService:
BuiltInField.last_update_date: 1234567890.0,
BuiltInField.source: "test_source",
}
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [
@ -698,14 +704,14 @@ class TestMetadataService:
MetadataService.disable_built_in_field(dataset)
# Assert: Verify the expected outcomes
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False
# Note: Document metadata update depends on DocumentService mock working correctly
# The main functionality (disabling built-in fields) is verified
def test_disable_built_in_field_already_disabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test disabling built-in fields when they are already disabled.
@ -732,13 +738,12 @@ class TestMetadataService:
MetadataService.disable_built_in_field(dataset)
# Assert: Verify the method returns early without changes
from extensions.ext_database import db
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False
def test_disable_built_in_field_with_no_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test disabling built-in fields for a dataset with no documents.
@ -757,10 +762,9 @@ class TestMetadataService:
# Enable built-in fields first
dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Mock DocumentService.get_working_documents_by_dataset_id to return empty list
mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = []
@ -769,10 +773,12 @@ class TestMetadataService:
MetadataService.disable_built_in_field(dataset)
# Assert: Verify the expected outcomes
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
assert dataset.built_in_field_enabled is False
def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_documents_metadata_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful update of documents metadata.
"""
@ -815,24 +821,25 @@ class TestMetadataService:
MetadataService.update_documents_metadata(dataset, operation_data)
# Assert: Verify the expected outcomes
from extensions.ext_database import db
# Verify document metadata was updated
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.doc_metadata is not None
assert "test_metadata" in document.doc_metadata
assert document.doc_metadata["test_metadata"] == "test_value"
# Verify metadata binding was created
binding = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first()
db_session_with_containers.query(DatasetMetadataBinding)
.filter_by(metadata_id=metadata.id, document_id=document.id)
.first()
)
assert binding is not None
assert binding.tenant_id == tenant.id
assert binding.dataset_id == dataset.id
def test_update_documents_metadata_with_built_in_fields_enabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update of documents metadata when built-in fields are enabled.
@ -850,10 +857,9 @@ class TestMetadataService:
# Enable built-in fields
dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
@ -884,7 +890,7 @@ class TestMetadataService:
# Assert: Verify the expected outcomes
# Verify document metadata was updated with both custom and built-in fields
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.doc_metadata is not None
assert "test_metadata" in document.doc_metadata
assert document.doc_metadata["test_metadata"] == "test_value"
@ -893,7 +899,7 @@ class TestMetadataService:
# The main functionality (custom metadata update) is verified
def test_update_documents_metadata_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update of documents metadata when document is not found.
@ -936,7 +942,7 @@ class TestMetadataService:
MetadataService.update_documents_metadata(dataset, operation_data)
def test_knowledge_base_metadata_lock_check_dataset_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata lock check for dataset operations.
@ -959,7 +965,7 @@ class TestMetadataService:
assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}"
def test_knowledge_base_metadata_lock_check_document_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata lock check for document operations.
@ -982,7 +988,7 @@ class TestMetadataService:
assert call_args[0][0] == f"document_metadata_lock_{document_id}"
def test_knowledge_base_metadata_lock_check_lock_exists(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata lock check when lock already exists.
@ -999,7 +1005,7 @@ class TestMetadataService:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
def test_knowledge_base_metadata_lock_check_document_lock_exists(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test metadata lock check when document lock already exists.
@ -1013,7 +1019,9 @@ class TestMetadataService:
with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."):
MetadataService.knowledge_base_metadata_lock_check(None, document_id)
def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_dataset_metadatas_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of dataset metadata information.
"""
@ -1046,10 +1054,8 @@ class TestMetadataService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(binding)
db_session_with_containers.commit()
# Act: Execute the method under test
result = MetadataService.get_dataset_metadatas(dataset)
@ -1071,7 +1077,7 @@ class TestMetadataService:
assert result["built_in_field_enabled"] is False
def test_get_dataset_metadatas_with_built_in_fields_enabled(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of dataset metadata when built-in fields are enabled.
@ -1086,10 +1092,9 @@ class TestMetadataService:
# Enable built-in fields
dataset.built_in_field_enabled = True
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Setup mocks
mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id
@ -1114,7 +1119,9 @@ class TestMetadataService:
# Verify built-in field status
assert result["built_in_field_enabled"] is True
def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_dataset_metadatas_no_metadata(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of dataset metadata when no metadata exists.
"""

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from models.account import TenantAccountJoin, TenantAccountRole
from models.model import Account, Tenant
@ -67,7 +68,7 @@ class TestModelLoadBalancingService:
"credential_schema": mock_credential_schema,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -88,18 +89,16 @@ class TestModelLoadBalancingService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -108,8 +107,8 @@ class TestModelLoadBalancingService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -117,7 +116,7 @@ class TestModelLoadBalancingService:
return account, tenant
def _create_test_provider_and_setting(
self, db_session_with_containers, tenant_id, mock_external_service_dependencies
self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies
):
"""
Helper method to create a test provider and provider model setting.
@ -132,8 +131,6 @@ class TestModelLoadBalancingService:
"""
fake = Faker()
from extensions.ext_database import db
# Create provider
provider = Provider(
tenant_id=tenant_id,
@ -141,8 +138,8 @@ class TestModelLoadBalancingService:
provider_type="custom",
is_valid=True,
)
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Create provider model setting
provider_model_setting = ProviderModelSetting(
@ -153,12 +150,14 @@ class TestModelLoadBalancingService:
enabled=True,
load_balancing_enabled=False,
)
db.session.add(provider_model_setting)
db.session.commit()
db_session_with_containers.add(provider_model_setting)
db_session_with_containers.commit()
return provider, provider_model_setting
def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_enable_model_load_balancing_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful model load balancing enablement.
@ -193,14 +192,15 @@ class TestModelLoadBalancingService:
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state
from extensions.ext_database import db
db.session.refresh(provider)
db.session.refresh(provider_model_setting)
db_session_with_containers.refresh(provider)
db_session_with_containers.refresh(provider_model_setting)
assert provider.id is not None
assert provider_model_setting.id is not None
def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_disable_model_load_balancing_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful model load balancing disablement.
@ -235,15 +235,14 @@ class TestModelLoadBalancingService:
assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value
# Verify database state
from extensions.ext_database import db
db.session.refresh(provider)
db.session.refresh(provider_model_setting)
db_session_with_containers.refresh(provider)
db_session_with_containers.refresh(provider_model_setting)
assert provider.id is not None
assert provider_model_setting.id is not None
def test_enable_model_load_balancing_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when provider does not exist.
@ -275,11 +274,12 @@ class TestModelLoadBalancingService:
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Verify no database state changes occurred
from extensions.ext_database import db
db.session.rollback()
db_session_with_containers.rollback()
def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_load_balancing_configs_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of load balancing configurations.
@ -298,7 +298,6 @@ class TestModelLoadBalancingService:
)
# Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id,
@ -309,11 +308,11 @@ class TestModelLoadBalancingService:
encrypted_config='{"api_key": "test_key"}',
enabled=True,
)
db.session.add(load_balancing_config)
db.session.commit()
db_session_with_containers.add(load_balancing_config)
db_session_with_containers.commit()
# Verify the config was created
db.session.refresh(load_balancing_config)
db_session_with_containers.refresh(load_balancing_config)
assert load_balancing_config.id is not None
# Setup mocks for get_load_balancing_configs method
@ -358,11 +357,11 @@ class TestModelLoadBalancingService:
assert configs[0]["ttl"] == 0
# Verify database state
db.session.refresh(load_balancing_config)
db_session_with_containers.refresh(load_balancing_config)
assert load_balancing_config.id is not None
def test_get_load_balancing_configs_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when provider does not exist in get_load_balancing_configs.
@ -394,12 +393,11 @@ class TestModelLoadBalancingService:
assert "Provider nonexistent_provider does not exist." in str(exc_info.value)
# Verify no database state changes occurred
from extensions.ext_database import db
db.session.rollback()
db_session_with_containers.rollback()
def test_get_load_balancing_configs_with_inherit_config(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test load balancing configs retrieval with inherit configuration.
@ -419,7 +417,6 @@ class TestModelLoadBalancingService:
)
# Create load balancing config
from extensions.ext_database import db
load_balancing_config = LoadBalancingModelConfig(
tenant_id=tenant.id,
@ -430,8 +427,8 @@ class TestModelLoadBalancingService:
encrypted_config='{"api_key": "test_key"}',
enabled=True,
)
db.session.add(load_balancing_config)
db.session.commit()
db_session_with_containers.add(load_balancing_config)
db_session_with_containers.commit()
# Setup mocks for inherit config scenario
mock_provider_config = mock_external_service_dependencies["provider_config"]
@ -467,11 +464,11 @@ class TestModelLoadBalancingService:
assert configs[1]["name"] == "config1"
# Verify database state
db.session.refresh(load_balancing_config)
db_session_with_containers.refresh(load_balancing_config)
assert load_balancing_config.id is not None
# Verify inherit config was created in database
inherit_configs = db.session.scalars(
inherit_configs = db_session_with_containers.scalars(
select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__")
).all()
assert len(inherit_configs) == 1

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.entities.model_entities import ModelStatus
from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType
@ -29,7 +30,7 @@ class TestModelProviderService:
"model_provider_factory": mock_model_provider_factory,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -50,18 +51,16 @@ class TestModelProviderService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -70,8 +69,8 @@ class TestModelProviderService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -80,7 +79,7 @@ class TestModelProviderService:
def _create_test_provider(
self,
db_session_with_containers,
db_session_with_containers: Session,
mock_external_service_dependencies,
tenant_id: str,
provider_name: str = "openai",
@ -109,16 +108,14 @@ class TestModelProviderService:
quota_used=0,
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
return provider
def _create_test_provider_model(
self,
db_session_with_containers,
db_session_with_containers: Session,
mock_external_service_dependencies,
tenant_id: str,
provider_name: str,
@ -149,16 +146,14 @@ class TestModelProviderService:
is_valid=True,
)
from extensions.ext_database import db
db.session.add(provider_model)
db.session.commit()
db_session_with_containers.add(provider_model)
db_session_with_containers.commit()
return provider_model
def _create_test_provider_model_setting(
self,
db_session_with_containers,
db_session_with_containers: Session,
mock_external_service_dependencies,
tenant_id: str,
provider_name: str,
@ -190,14 +185,12 @@ class TestModelProviderService:
load_balancing_enabled=False,
)
from extensions.ext_database import db
db.session.add(provider_model_setting)
db.session.commit()
db_session_with_containers.add(provider_model_setting)
db_session_with_containers.commit()
return provider_model_setting
def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful provider list retrieval.
@ -275,7 +268,7 @@ class TestModelProviderService:
mock_provider_config.is_custom_configuration_available.assert_called_once()
def test_get_provider_list_with_model_type_filter(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test provider list retrieval with model type filtering.
@ -374,7 +367,9 @@ class TestModelProviderService:
assert result[0].provider == "cohere"
assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types
def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_models_by_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of models by provider.
@ -485,7 +480,9 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_configurations.get_models.assert_called_once_with(provider="openai")
def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_provider_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of provider credentials.
@ -543,7 +540,7 @@ class TestModelProviderService:
mock_method.assert_called_once_with(tenant.id, "openai")
def test_provider_credentials_validate_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful validation of provider credentials.
@ -585,7 +582,7 @@ class TestModelProviderService:
mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials)
def test_provider_credentials_validate_invalid_provider(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test validation failure for non-existent provider.
@ -617,7 +614,7 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
def test_get_default_model_of_model_type_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of default model for a specific model type.
@ -673,7 +670,7 @@ class TestModelProviderService:
mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM)
def test_update_default_model_of_model_type_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful update of default model for a specific model type.
@ -706,7 +703,9 @@ class TestModelProviderService:
tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4"
)
def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_model_provider_icon_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of model provider icon.
@ -743,7 +742,9 @@ class TestModelProviderService:
# Verify mock interactions
mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US")
def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_switch_preferred_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful switching of preferred provider type.
@ -779,7 +780,7 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configuration.switch_preferred_provider_type.assert_called_once()
def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful enabling of a model.
@ -815,7 +816,9 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4")
def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_model_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of model credentials.
@ -872,7 +875,9 @@ class TestModelProviderService:
# Verify the method was called with correct parameters
mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None)
def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_model_credentials_validate_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful validation of model credentials.
@ -914,7 +919,9 @@ class TestModelProviderService:
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials
)
def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_model_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful saving of model credentials.
@ -955,7 +962,9 @@ class TestModelProviderService:
model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname"
)
def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_remove_model_credentials_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful removal of model credentials.
@ -993,7 +1002,9 @@ class TestModelProviderService:
model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6"
)
def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_models_by_model_type_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of models by model type.
@ -1070,7 +1081,9 @@ class TestModelProviderService:
mock_provider_manager.get_configurations.assert_called_once_with(tenant.id)
mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_model_parameter_rules_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of model parameter rules.
@ -1137,7 +1150,7 @@ class TestModelProviderService:
)
def test_get_model_parameter_rules_no_credentials(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test parameter rules retrieval when no credentials are available.
@ -1181,7 +1194,7 @@ class TestModelProviderService:
)
def test_get_model_parameter_rules_provider_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test parameter rules retrieval when provider does not exist.

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.model import EndUser, Message
from models.web import SavedMessage
@ -38,7 +39,7 @@ class TestSavedMessageService:
"message_service": mock_message_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -85,7 +86,7 @@ class TestSavedMessageService:
return app, account
def _create_test_end_user(self, db_session_with_containers, app):
def _create_test_end_user(self, db_session_with_containers: Session, app):
"""
Helper method to create a test end user for testing.
@ -108,14 +109,12 @@ class TestSavedMessageService:
is_anonymous=False,
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
return end_user
def _create_test_message(self, db_session_with_containers, app, user):
def _create_test_message(self, db_session_with_containers: Session, app, user):
"""
Helper method to create a test message for testing.
@ -143,10 +142,8 @@ class TestSavedMessageService:
mode="chat",
)
from extensions.ext_database import db
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
# Create message
message = Message(
@ -168,13 +165,13 @@ class TestSavedMessageService:
status="success",
)
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
return message
def test_pagination_by_last_id_success_with_account_user(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination by last ID with account user.
@ -207,10 +204,8 @@ class TestSavedMessageService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add_all([saved_message1, saved_message2])
db.session.commit()
db_session_with_containers.add_all([saved_message1, saved_message2])
db_session_with_containers.commit()
# Mock MessageService.pagination_by_last_id return value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -240,15 +235,15 @@ class TestSavedMessageService:
assert actual_include_ids == expected_include_ids
# Verify database state
db.session.refresh(saved_message1)
db.session.refresh(saved_message2)
db_session_with_containers.refresh(saved_message1)
db_session_with_containers.refresh(saved_message2)
assert saved_message1.id is not None
assert saved_message2.id is not None
assert saved_message1.created_by_role == "account"
assert saved_message2.created_by_role == "account"
def test_pagination_by_last_id_success_with_end_user(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination by last ID with end user.
@ -282,10 +277,8 @@ class TestSavedMessageService:
created_by=end_user.id,
)
from extensions.ext_database import db
db.session.add_all([saved_message1, saved_message2])
db.session.commit()
db_session_with_containers.add_all([saved_message1, saved_message2])
db_session_with_containers.commit()
# Mock MessageService.pagination_by_last_id return value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -317,14 +310,16 @@ class TestSavedMessageService:
assert actual_include_ids == expected_include_ids
# Verify database state
db.session.refresh(saved_message1)
db.session.refresh(saved_message2)
db_session_with_containers.refresh(saved_message1)
db_session_with_containers.refresh(saved_message2)
assert saved_message1.id is not None
assert saved_message2.id is not None
assert saved_message1.created_by_role == "end_user"
assert saved_message2.created_by_role == "end_user"
def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_success_with_new_message(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful save of a new message.
@ -347,10 +342,9 @@ class TestSavedMessageService:
# Assert: Verify the expected outcomes
# Check if saved message was created in database
from extensions.ext_database import db
saved_message = (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -373,10 +367,12 @@ class TestSavedMessageService:
)
# Verify database state
db.session.refresh(saved_message)
db_session_with_containers.refresh(saved_message)
assert saved_message.id is not None
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_last_id_error_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when no user is provided.
@ -396,12 +392,11 @@ class TestSavedMessageService:
assert "User is required" in str(exc_info.value)
# Verify no database operations were performed
from extensions.ext_database import db
saved_messages = db.session.query(SavedMessage).all()
saved_messages = db_session_with_containers.query(SavedMessage).all()
assert len(saved_messages) == 0
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when saving message with no user.
@ -422,10 +417,9 @@ class TestSavedMessageService:
assert result is None
# Verify no saved message was created
from extensions.ext_database import db
saved_message = (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -435,7 +429,9 @@ class TestSavedMessageService:
assert saved_message is None
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_success_existing_message(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful deletion of an existing saved message.
@ -457,14 +453,12 @@ class TestSavedMessageService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(saved_message)
db.session.commit()
db_session_with_containers.add(saved_message)
db_session_with_containers.commit()
# Verify saved message exists
assert (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -481,7 +475,7 @@ class TestSavedMessageService:
# Assert: Verify the expected outcomes
# Check if saved message was deleted from database
deleted_saved_message = (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -494,11 +488,13 @@ class TestSavedMessageService:
assert deleted_saved_message is None
# Verify database state
db.session.commit()
db_session_with_containers.commit()
# The message should still exist, only the saved_message should be deleted
assert db.session.query(Message).where(Message.id == message.id).first() is not None
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None
def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_last_id_error_no_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when no user is provided.
@ -522,7 +518,7 @@ class TestSavedMessageService:
# Instead, we verify that the error was properly raised
pass
def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when saving message with no user.
@ -543,10 +539,9 @@ class TestSavedMessageService:
assert result is None
# Verify no saved message was created
from extensions.ext_database import db
saved_message = (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -556,7 +551,9 @@ class TestSavedMessageService:
assert saved_message is None
def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_success_existing_message(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful deletion of an existing saved message.
@ -578,14 +575,12 @@ class TestSavedMessageService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(saved_message)
db.session.commit()
db_session_with_containers.add(saved_message)
db_session_with_containers.commit()
# Verify saved message exists
assert (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -602,7 +597,7 @@ class TestSavedMessageService:
# Assert: Verify the expected outcomes
# Check if saved message was deleted from database
deleted_saved_message = (
db.session.query(SavedMessage)
db_session_with_containers.query(SavedMessage)
.where(
SavedMessage.app_id == app.id,
SavedMessage.message_id == message.id,
@ -615,6 +610,6 @@ class TestSavedMessageService:
assert deleted_saved_message is None
# Verify database state
db.session.commit()
db_session_with_containers.commit()
# The message should still exist, only the saved_message should be deleted
assert db.session.query(Message).where(Message.id == message.id).first() is not None
assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None

View File

@ -4,6 +4,7 @@ from unittest.mock import create_autospec, patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@ -29,7 +30,7 @@ class TestTagService:
"current_user": mock_current_user,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -50,18 +51,16 @@ class TestTagService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -70,8 +69,8 @@ class TestTagService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -82,7 +81,7 @@ class TestTagService:
return account, tenant
def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id):
def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id):
"""
Helper method to create a test dataset for testing.
@ -107,14 +106,12 @@ class TestTagService:
created_by=mock_external_service_dependencies["current_user"].id,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id):
def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id):
"""
Helper method to create a test app for testing.
@ -141,15 +138,13 @@ class TestTagService:
created_by=mock_external_service_dependencies["current_user"].id,
)
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def _create_test_tags(
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3
):
"""
Helper method to create test tags for testing.
@ -176,16 +171,14 @@ class TestTagService:
)
tags.append(tag)
from extensions.ext_database import db
for tag in tags:
db.session.add(tag)
db.session.commit()
db_session_with_containers.add(tag)
db_session_with_containers.commit()
return tags
def _create_test_tag_bindings(
self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id
self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id
):
"""
Helper method to create test tag bindings for testing.
@ -211,15 +204,13 @@ class TestTagService:
)
tag_bindings.append(tag_binding)
from extensions.ext_database import db
for tag_binding in tag_bindings:
db.session.add(tag_binding)
db.session.commit()
db_session_with_containers.add(tag_binding)
db_session_with_containers.commit()
return tag_bindings
def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of tags with binding count.
@ -270,7 +261,9 @@ class TestTagService:
# The ordering is handled by the database, we just verify the results are returned
assert len(result) == 3
def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tags_with_keyword_filter(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag retrieval with keyword filtering.
@ -291,12 +284,11 @@ class TestTagService:
)
# Update tag names to make them searchable
from extensions.ext_database import db
tags[0].name = "python_development"
tags[1].name = "machine_learning"
tags[2].name = "web_development"
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the method under test with keyword filter
result = TagService.get_tags("app", tenant.id, keyword="development")
@ -314,7 +306,7 @@ class TestTagService:
assert len(result_no_match) == 0
def test_get_tags_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
r"""
Test tag retrieval with special characters in keyword to verify SQL injection prevention.
@ -330,8 +322,6 @@ class TestTagService:
db_session_with_containers, mock_external_service_dependencies
)
from extensions.ext_database import db
# Create tags with special characters in names
tag_with_percent = Tag(
name="50% discount",
@ -340,7 +330,7 @@ class TestTagService:
created_by=account.id,
)
tag_with_percent.id = str(uuid.uuid4())
db.session.add(tag_with_percent)
db_session_with_containers.add(tag_with_percent)
tag_with_underscore = Tag(
name="test_data_tag",
@ -349,7 +339,7 @@ class TestTagService:
created_by=account.id,
)
tag_with_underscore.id = str(uuid.uuid4())
db.session.add(tag_with_underscore)
db_session_with_containers.add(tag_with_underscore)
tag_with_backslash = Tag(
name="path\\to\\tag",
@ -358,7 +348,7 @@ class TestTagService:
created_by=account.id,
)
tag_with_backslash.id = str(uuid.uuid4())
db.session.add(tag_with_backslash)
db_session_with_containers.add(tag_with_backslash)
# Create tag that should NOT match
tag_no_match = Tag(
@ -368,9 +358,9 @@ class TestTagService:
created_by=account.id,
)
tag_no_match.id = str(uuid.uuid4())
db.session.add(tag_no_match)
db_session_with_containers.add(tag_no_match)
db.session.commit()
db_session_with_containers.commit()
# Act & Assert: Test 1 - Search with % character
result = TagService.get_tags("app", tenant.id, keyword="50%")
@ -392,7 +382,7 @@ class TestTagService:
assert len(result) == 1
assert all("50%" in item.name for item in result)
def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test tag retrieval when no tags exist.
@ -414,7 +404,9 @@ class TestTagService:
assert len(result) == 0
assert isinstance(result, list)
def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_target_ids_by_tag_ids_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of target IDs by tag IDs.
@ -469,7 +461,7 @@ class TestTagService:
assert second_dataset_count == 1
def test_get_target_ids_by_tag_ids_empty_tag_ids(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test target ID retrieval with empty tag IDs list.
@ -493,7 +485,7 @@ class TestTagService:
assert isinstance(result, list)
def test_get_target_ids_by_tag_ids_no_matching_tags(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test target ID retrieval when no tags match the criteria.
@ -521,7 +513,7 @@ class TestTagService:
assert len(result) == 0
assert isinstance(result, list)
def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of tags by tag name.
@ -542,11 +534,10 @@ class TestTagService:
)
# Update tag names to make them searchable
from extensions.ext_database import db
tags[0].name = "python_tag"
tags[1].name = "ml_tag"
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the method under test
result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag")
@ -558,7 +549,9 @@ class TestTagService:
assert result[0].type == "app"
assert result[0].tenant_id == tenant.id
def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tag_by_tag_name_no_matches(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag retrieval by name when no matches exist.
@ -580,7 +573,9 @@ class TestTagService:
assert len(result) == 0
assert isinstance(result, list)
def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tag_by_tag_name_empty_parameters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag retrieval by name with empty parameters.
@ -605,7 +600,9 @@ class TestTagService:
assert result_empty_name is not None
assert len(result_empty_name) == 0
def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tags_by_target_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of tags by target ID.
@ -644,7 +641,9 @@ class TestTagService:
assert tag.tenant_id == tenant.id
assert tag.id in [t.id for t in tags]
def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tags_by_target_id_no_bindings(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag retrieval by target ID when no tags are bound.
@ -669,7 +668,7 @@ class TestTagService:
assert len(result) == 0
assert isinstance(result, list)
def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful tag creation.
@ -698,17 +697,18 @@ class TestTagService:
assert result.id is not None
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
# Verify tag was actually saved to database
saved_tag = db.session.query(Tag).where(Tag.id == result.id).first()
saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first()
assert saved_tag is not None
assert saved_tag.name == "test_tag_name"
def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_tags_duplicate_name_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag creation with duplicate name.
@ -731,7 +731,7 @@ class TestTagService:
TagService.save_tags(tag_args)
assert "Tag name already exists" in str(exc_info.value)
def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful tag update.
@ -763,17 +763,16 @@ class TestTagService:
assert result.id == tag.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.name == "updated_name"
# Verify tag was actually updated in database
updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first()
updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
assert updated_tag is not None
assert updated_tag.name == "updated_name"
def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test tag update for non-existent tag.
@ -799,7 +798,9 @@ class TestTagService:
TagService.update_tags(update_args, non_existent_tag_id)
assert "Tag not found" in str(exc_info.value)
def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_tags_duplicate_name_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag update with duplicate name.
@ -828,7 +829,9 @@ class TestTagService:
TagService.update_tags(update_args, tag2.id)
assert "Tag name already exists" in str(exc_info.value)
def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tag_binding_count_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of tag binding count.
@ -863,7 +866,7 @@ class TestTagService:
assert result_tag_without_bindings == 0
def test_get_tag_binding_count_non_existent_tag(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test binding count retrieval for non-existent tag.
@ -889,7 +892,7 @@ class TestTagService:
# Assert: Verify the expected outcomes
assert result == 0
def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful tag deletion.
@ -916,12 +919,11 @@ class TestTagService:
)
# Verify tag and binding exist before deletion
from extensions.ext_database import db
tag_before = db.session.query(Tag).where(Tag.id == tag.id).first()
tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
assert tag_before is not None
binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
assert binding_before is not None
# Act: Execute the method under test
@ -929,14 +931,14 @@ class TestTagService:
# Assert: Verify the expected outcomes
# Verify tag was deleted
tag_after = db.session.query(Tag).where(Tag.id == tag.id).first()
tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first()
assert tag_after is None
# Verify tag binding was deleted
binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first()
assert binding_after is None
def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test tag deletion for non-existent tag.
@ -960,7 +962,7 @@ class TestTagService:
TagService.delete_tag(non_existent_tag_id)
assert "Tag not found" in str(exc_info.value)
def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful tag binding creation.
@ -988,12 +990,11 @@ class TestTagService:
TagService.save_tag_binding(binding_args)
# Assert: Verify the expected outcomes
from extensions.ext_database import db
# Verify tag bindings were created
for tag in tags:
binding = (
db.session.query(TagBinding)
db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
)
@ -1001,7 +1002,9 @@ class TestTagService:
assert binding.tenant_id == tenant.id
assert binding.created_by == account.id
def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_tag_binding_duplicate_handling(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag binding creation with duplicate bindings.
@ -1032,15 +1035,16 @@ class TestTagService:
TagService.save_tag_binding(binding_args)
# Assert: Verify the expected outcomes
from extensions.ext_database import db
# Verify only one binding exists
bindings = db.session.scalars(
bindings = db_session_with_containers.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all()
assert len(bindings) == 1
def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies):
def test_save_tag_binding_invalid_target_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag binding creation with invalid target type.
@ -1071,7 +1075,7 @@ class TestTagService:
TagService.save_tag_binding(binding_args)
assert "Invalid binding type" in str(exc_info.value)
def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful tag binding deletion.
@ -1098,10 +1102,11 @@ class TestTagService:
)
# Verify binding exists before deletion
from extensions.ext_database import db
binding_before = (
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first()
db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
)
assert binding_before is not None
@ -1112,12 +1117,14 @@ class TestTagService:
# Assert: Verify the expected outcomes
# Verify tag binding was deleted
binding_after = (
db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first()
db_session_with_containers.query(TagBinding)
.where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id)
.first()
)
assert binding_after is None
def test_delete_tag_binding_non_existent_binding(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tag binding deletion for non-existent binding.
@ -1145,15 +1152,14 @@ class TestTagService:
# Assert: Verify the expected outcomes
# No error should be raised, and database state should remain unchanged
from extensions.ext_database import db
bindings = db.session.scalars(
bindings = db_session_with_containers.scalars(
select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id)
).all()
assert len(bindings) == 0
def test_check_target_exists_knowledge_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful target existence check for knowledge type.
@ -1179,7 +1185,7 @@ class TestTagService:
# No exception should be raised for existing dataset
def test_check_target_exists_knowledge_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test target existence check for non-existent knowledge dataset.
@ -1204,7 +1210,9 @@ class TestTagService:
TagService.check_target_exists("knowledge", non_existent_dataset_id)
assert "Dataset not found" in str(exc_info.value)
def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_check_target_exists_app_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful target existence check for app type.
@ -1228,7 +1236,9 @@ class TestTagService:
# Assert: Verify the expected outcomes
# No exception should be raised for existing app
def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_check_target_exists_app_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test target existence check for non-existent app.
@ -1252,7 +1262,9 @@ class TestTagService:
TagService.check_target_exists("app", non_existent_app_id)
assert "App not found" in str(exc_info.value)
def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies):
def test_check_target_exists_invalid_type(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test target existence check for invalid type.

View File

@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity
from extensions.ext_database import db
from models.provider_ids import TriggerProviderID
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
@ -47,7 +47,7 @@ class TestTriggerProviderService:
"account_feature_service": mock_account_feature_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -84,7 +84,7 @@ class TestTriggerProviderService:
def _create_test_subscription(
self,
db_session_with_containers,
db_session_with_containers: Session,
tenant_id,
user_id,
provider_id,
@ -135,14 +135,14 @@ class TestTriggerProviderService:
expires_at=-1,
)
db.session.add(subscription)
db.session.commit()
db.session.refresh(subscription)
db_session_with_containers.add(subscription)
db_session_with_containers.commit()
db_session_with_containers.refresh(subscription)
return subscription
def test_rebuild_trigger_subscription_success_with_merged_credentials(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful rebuild with credential merging (HIDDEN_VALUE handling).
@ -217,7 +217,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["api_secret"] == "new-secret-value" # New value
# Verify database state was updated
db.session.refresh(subscription)
db_session_with_containers.refresh(subscription)
assert subscription.name == "updated_name"
assert subscription.parameters == {"param1": "updated_value"}
@ -244,7 +244,7 @@ class TestTriggerProviderService:
)
def test_rebuild_trigger_subscription_with_all_new_credentials(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are new (no HIDDEN_VALUE).
@ -304,7 +304,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["api_secret"] == "completely-new-secret"
def test_rebuild_trigger_subscription_with_all_hidden_values(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing).
@ -363,7 +363,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["api_secret"] == original_credentials["api_secret"]
def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original.
@ -422,7 +422,7 @@ class TestTriggerProviderService:
assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE
def test_rebuild_trigger_subscription_rollback_on_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that transaction is rolled back on error.
@ -470,12 +470,12 @@ class TestTriggerProviderService:
)
# Verify subscription state was not changed (rolled back)
db.session.refresh(subscription)
db_session_with_containers.refresh(subscription)
assert subscription.name == original_name
assert subscription.parameters == original_parameters
def test_rebuild_trigger_subscription_subscription_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error when subscription is not found.
@ -501,7 +501,7 @@ class TestTriggerProviderService:
)
def test_rebuild_trigger_subscription_name_uniqueness_check(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that name uniqueness is checked when updating name.

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
@ -45,7 +46,7 @@ class TestWebConversationService:
"account_feature_service": mock_account_feature_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -90,7 +91,7 @@ class TestWebConversationService:
return app, account
def _create_test_end_user(self, db_session_with_containers, app):
def _create_test_end_user(self, db_session_with_containers: Session, app):
"""
Helper method to create a test end user for testing.
@ -111,14 +112,12 @@ class TestWebConversationService:
tenant_id=app.tenant_id,
)
from extensions.ext_database import db
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
return end_user
def _create_test_conversation(self, db_session_with_containers, app, user, fake):
def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake):
"""
Helper method to create a test conversation for testing.
@ -152,14 +151,14 @@ class TestWebConversationService:
is_deleted=False,
)
from extensions.ext_database import db
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
return conversation
def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_pagination_by_last_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination by last ID with basic parameters.
"""
@ -194,7 +193,7 @@ class TestWebConversationService:
assert result.data[1].updated_at >= result.data[2].updated_at
def test_pagination_by_last_id_with_pinned_filter(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by last ID with pinned conversation filter.
@ -222,11 +221,9 @@ class TestWebConversationService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(pinned_conversation1)
db.session.add(pinned_conversation2)
db.session.commit()
db_session_with_containers.add(pinned_conversation1)
db_session_with_containers.add(pinned_conversation2)
db_session_with_containers.commit()
# Test pagination with pinned filter
result = WebConversationService.pagination_by_last_id(
@ -251,7 +248,7 @@ class TestWebConversationService:
assert set(returned_ids) == set(expected_ids)
def test_pagination_by_last_id_with_unpinned_filter(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination by last ID with unpinned conversation filter.
@ -273,10 +270,8 @@ class TestWebConversationService:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(pinned_conversation)
db.session.commit()
db_session_with_containers.add(pinned_conversation)
db_session_with_containers.commit()
# Test pagination with unpinned filter
result = WebConversationService.pagination_by_last_id(
@ -303,7 +298,7 @@ class TestWebConversationService:
expected_unpinned_ids = [conv.id for conv in conversations[1:]]
assert set(returned_ids) == set(expected_unpinned_ids)
def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful pinning of a conversation.
"""
@ -317,10 +312,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account)
# Verify the conversation was pinned
from extensions.ext_database import db
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -336,7 +330,9 @@ class TestWebConversationService:
assert pinned_conversation.created_by_role == "account"
assert pinned_conversation.created_by == account.id
def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies):
def test_pin_conversation_already_pinned(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pinning a conversation that is already pinned (should not create duplicate).
"""
@ -353,9 +349,8 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account)
# Verify only one pinned conversation record exists
from extensions.ext_database import db
pinned_conversations = db.session.scalars(
pinned_conversations = db_session_with_containers.scalars(
select(PinnedConversation).where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -366,7 +361,9 @@ class TestWebConversationService:
assert len(pinned_conversations) == 1
def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies):
def test_pin_conversation_with_end_user(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pinning a conversation with an end user.
"""
@ -383,10 +380,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, end_user)
# Verify the conversation was pinned
from extensions.ext_database import db
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -402,7 +398,7 @@ class TestWebConversationService:
assert pinned_conversation.created_by_role == "end_user"
assert pinned_conversation.created_by == end_user.id
def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful unpinning of a conversation.
"""
@ -416,10 +412,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account)
# Verify it was pinned
from extensions.ext_database import db
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -436,7 +431,7 @@ class TestWebConversationService:
# Verify it was unpinned
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -448,7 +443,9 @@ class TestWebConversationService:
assert pinned_conversation is None
def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies):
def test_unpin_conversation_not_pinned(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test unpinning a conversation that is not pinned (should not cause error).
"""
@ -462,10 +459,9 @@ class TestWebConversationService:
WebConversationService.unpin(app, conversation.id, account)
# Verify no pinned conversation record exists
from extensions.ext_database import db
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -478,7 +474,7 @@ class TestWebConversationService:
assert pinned_conversation is None
def test_pagination_by_last_id_user_required_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that pagination_by_last_id raises ValueError when user is None.
@ -499,7 +495,7 @@ class TestWebConversationService:
sort_by="-updated_at",
)
def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test that pin method returns early when user is None.
"""
@ -513,10 +509,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, None)
# Verify no pinned conversation was created
from extensions.ext_database import db
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -526,7 +521,9 @@ class TestWebConversationService:
assert pinned_conversation is None
def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies):
def test_unpin_conversation_user_none(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that unpin method returns early when user is None.
"""
@ -540,10 +537,9 @@ class TestWebConversationService:
WebConversationService.pin(app, conversation.id, account)
# Verify it was pinned
from extensions.ext_database import db
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,
@ -560,7 +556,7 @@ class TestWebConversationService:
# Verify the conversation is still pinned
pinned_conversation = (
db.session.query(PinnedConversation)
db_session_with_containers.query(PinnedConversation)
.where(
PinnedConversation.app_id == app.id,
PinnedConversation.conversation_id == conversation.id,

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
from libs.password import hash_password
@ -45,7 +46,7 @@ class TestWebAppAuthService:
"enterprise_service": mock_enterprise_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -68,18 +69,16 @@ class TestWebAppAuthService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -88,15 +87,17 @@ class TestWebAppAuthService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_with_password(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Helper method to create a test account with password for testing.
@ -131,18 +132,16 @@ class TestWebAppAuthService:
account.password = base64.b64encode(password_hash).decode()
account.password_salt = base64.b64encode(salt).decode()
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -151,15 +150,17 @@ class TestWebAppAuthService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant, password
def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant):
def _create_test_app_and_site(
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant
):
"""
Helper method to create a test app and site for testing.
@ -188,10 +189,8 @@ class TestWebAppAuthService:
enable_api=True,
)
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
# Create site
site = Site(
@ -203,12 +202,12 @@ class TestWebAppAuthService:
status="normal",
customize_token_strategy="not_allow",
)
db.session.add(site)
db.session.commit()
db_session_with_containers.add(site)
db_session_with_containers.commit()
return app, site
def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful authentication with valid email and password.
@ -233,14 +232,15 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.password is not None
assert result.password_salt is not None
def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_authenticate_account_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test authentication with non-existent email.
@ -262,7 +262,7 @@ class TestWebAppAuthService:
with pytest.raises(AccountNotFoundError):
WebAppAuthService.authenticate(non_existent_email, "any_password")
def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies):
def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test authentication with banned account.
@ -292,10 +292,8 @@ class TestWebAppAuthService:
account.password = base64.b64encode(password_hash).decode()
account.password_salt = base64.b64encode(salt).decode()
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling
with pytest.raises(AccountLoginError) as exc_info:
@ -303,7 +301,9 @@ class TestWebAppAuthService:
assert "Account is banned." in str(exc_info.value)
def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies):
def test_authenticate_invalid_password(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test authentication with invalid password.
@ -323,7 +323,7 @@ class TestWebAppAuthService:
assert "Invalid email or password." in str(exc_info.value)
def test_authenticate_account_without_password(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test authentication for account without password.
@ -345,10 +345,8 @@ class TestWebAppAuthService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling
with pytest.raises(AccountPasswordError) as exc_info:
@ -356,7 +354,7 @@ class TestWebAppAuthService:
assert "Invalid email or password." in str(exc_info.value)
def test_login_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful login and JWT token generation.
@ -388,7 +386,9 @@ class TestWebAppAuthService:
assert call_args["auth_type"] == "internal"
assert "exp" in call_args
def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_user_through_email_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful user retrieval through email.
@ -413,12 +413,13 @@ class TestWebAppAuthService:
assert result.status == AccountStatus.ACTIVE
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_user_through_email_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test user retrieval with non-existent email.
@ -435,7 +436,9 @@ class TestWebAppAuthService:
# Assert: Verify proper handling
assert result is None
def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_user_through_email_banned(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test user retrieval with banned account.
@ -456,10 +459,8 @@ class TestWebAppAuthService:
status=AccountStatus.BANNED,
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling
with pytest.raises(Unauthorized) as exc_info:
@ -468,7 +469,7 @@ class TestWebAppAuthService:
assert "Account is banned." in str(exc_info.value)
def test_send_email_code_login_email_with_account(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test sending email code login email with account.
@ -509,7 +510,7 @@ class TestWebAppAuthService:
assert "code" in mail_call_args[1]
def test_send_email_code_login_email_with_email_only(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test sending email code login email with email only.
@ -549,7 +550,7 @@ class TestWebAppAuthService:
assert "code" in mail_call_args[1]
def test_send_email_code_login_email_no_email_provided(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test sending email code login email without providing email.
@ -566,7 +567,9 @@ class TestWebAppAuthService:
assert "Email must be provided." in str(exc_info.value)
def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_email_code_login_data_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of email code login data.
@ -593,7 +596,9 @@ class TestWebAppAuthService:
"mock_token", "email_code_login"
)
def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_email_code_login_data_no_data(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test email code login data retrieval when no data exists.
@ -617,7 +622,7 @@ class TestWebAppAuthService:
)
def test_revoke_email_code_login_token_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful revocation of email code login token.
@ -636,7 +641,7 @@ class TestWebAppAuthService:
"mock_token", "email_code_login"
)
def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful end user creation.
@ -668,14 +673,15 @@ class TestWebAppAuthService:
assert result.external_user_id == "enterpriseuser"
# Verify database state
from extensions.ext_database import db
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.created_at is not None
assert result.updated_at is not None
def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_end_user_site_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test end user creation with non-existent site code.
@ -693,7 +699,9 @@ class TestWebAppAuthService:
assert "Site not found." in str(exc_info.value)
def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_end_user_app_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test end user creation when app is not found.
@ -708,10 +716,8 @@ class TestWebAppAuthService:
status="normal",
)
from extensions.ext_database import db
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
site = Site(
app_id="00000000-0000-0000-0000-000000000000",
@ -722,8 +728,8 @@ class TestWebAppAuthService:
status="normal",
customize_token_strategy="not_allow",
)
db.session.add(site)
db.session.commit()
db_session_with_containers.add(site)
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling
with pytest.raises(NotFound) as exc_info:
@ -732,7 +738,7 @@ class TestWebAppAuthService:
assert "App not found." in str(exc_info.value)
def test_is_app_require_permission_check_with_access_mode_private(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test permission check requirement for private access mode.
@ -751,7 +757,7 @@ class TestWebAppAuthService:
assert result is True
def test_is_app_require_permission_check_with_access_mode_public(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test permission check requirement for public access mode.
@ -770,7 +776,7 @@ class TestWebAppAuthService:
assert result is False
def test_is_app_require_permission_check_with_app_code(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test permission check requirement using app code.
@ -796,7 +802,7 @@ class TestWebAppAuthService:
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id")
def test_is_app_require_permission_check_no_parameters(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test permission check requirement with no parameters.
@ -814,7 +820,7 @@ class TestWebAppAuthService:
assert "Either app_code or app_id must be provided." in str(exc_info.value)
def test_get_app_auth_type_with_access_mode_public(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app authentication type for public access mode.
@ -833,7 +839,7 @@ class TestWebAppAuthService:
assert result == WebAppAuthType.PUBLIC
def test_get_app_auth_type_with_access_mode_private(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app authentication type for private access mode.
@ -851,7 +857,9 @@ class TestWebAppAuthService:
# Assert: Verify correct result
assert result == WebAppAuthType.INTERNAL
def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_app_auth_type_with_app_code(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app authentication type using app code.
@ -878,7 +886,9 @@ class TestWebAppAuthService:
"enterprise_service"
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_app_auth_type_no_parameters(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test app authentication type with no parameters.

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.entities.workflow_execution import WorkflowExecutionStatus
from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun
@ -48,7 +49,7 @@ class TestWorkflowAppService:
"account_feature_service": mock_account_feature_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -96,7 +97,7 @@ class TestWorkflowAppService:
return app, account
def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test tenant and account for testing.
@ -126,7 +127,7 @@ class TestWorkflowAppService:
return tenant, account
def _create_test_app(self, db_session_with_containers, tenant, account):
def _create_test_app(self, db_session_with_containers: Session, tenant, account):
"""
Helper method to create a test app for testing.
@ -160,7 +161,7 @@ class TestWorkflowAppService:
return app
def _create_test_workflow_data(self, db_session_with_containers, app, account):
def _create_test_workflow_data(self, db_session_with_containers: Session, app, account):
"""
Helper method to create test workflow data for testing.
@ -174,8 +175,6 @@ class TestWorkflowAppService:
"""
fake = Faker()
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -188,8 +187,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create workflow run
workflow_run = WorkflowRun(
@ -212,8 +211,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
# Create workflow app log
workflow_app_log = WorkflowAppLog(
@ -227,13 +226,13 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
return workflow, workflow_run, workflow_app_log
def test_get_paginate_workflow_app_logs_basic_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination of workflow app logs with basic parameters.
@ -268,13 +267,12 @@ class TestWorkflowAppService:
assert log_entry.workflow_run_id == workflow_run.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(workflow_app_log)
db_session_with_containers.refresh(workflow_app_log)
assert workflow_app_log.id is not None
def test_get_paginate_workflow_app_logs_with_keyword_search(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with keyword search functionality.
@ -287,11 +285,10 @@ class TestWorkflowAppService:
)
# Update workflow run with searchable content
from extensions.ext_database import db
workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"})
workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"})
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the method under test with keyword search
service = WorkflowAppService()
@ -317,7 +314,7 @@ class TestWorkflowAppService:
assert len(result_no_match["data"]) == 0
def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
r"""
Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention.
@ -332,8 +329,6 @@ class TestWorkflowAppService:
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account)
from extensions.ext_database import db
service = WorkflowAppService()
# Test 1: Search with % character
@ -353,8 +348,8 @@ class TestWorkflowAppService:
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_1)
db.session.flush()
db_session_with_containers.add(workflow_run_1)
db_session_with_containers.flush()
workflow_app_log_1 = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -367,8 +362,8 @@ class TestWorkflowAppService:
)
workflow_app_log_1.id = str(uuid.uuid4())
workflow_app_log_1.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_1)
db.session.commit()
db_session_with_containers.add(workflow_app_log_1)
db_session_with_containers.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
@ -395,8 +390,8 @@ class TestWorkflowAppService:
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_2)
db.session.flush()
db_session_with_containers.add(workflow_run_2)
db_session_with_containers.flush()
workflow_app_log_2 = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -409,8 +404,8 @@ class TestWorkflowAppService:
)
workflow_app_log_2.id = str(uuid.uuid4())
workflow_app_log_2.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_2)
db.session.commit()
db_session_with_containers.add(workflow_app_log_2)
db_session_with_containers.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20
@ -437,8 +432,8 @@ class TestWorkflowAppService:
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(workflow_run_4)
db.session.flush()
db_session_with_containers.add(workflow_run_4)
db_session_with_containers.flush()
workflow_app_log_4 = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -451,8 +446,8 @@ class TestWorkflowAppService:
)
workflow_app_log_4.id = str(uuid.uuid4())
workflow_app_log_4.created_at = datetime.now(UTC)
db.session.add(workflow_app_log_4)
db.session.commit()
db_session_with_containers.add(workflow_app_log_4)
db_session_with_containers.commit()
result = service.get_paginate_workflow_app_logs(
session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20
@ -467,7 +462,7 @@ class TestWorkflowAppService:
assert workflow_run_4.id not in found_run_ids
def test_get_paginate_workflow_app_logs_with_status_filter(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with status filtering.
@ -476,8 +471,6 @@ class TestWorkflowAppService:
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -490,8 +483,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create workflow runs with different statuses
statuses = ["succeeded", "failed", "running", "stopped"]
@ -519,8 +512,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None,
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -533,8 +526,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log)
@ -568,7 +561,7 @@ class TestWorkflowAppService:
assert result_running["data"][0].workflow_run.status == "running"
def test_get_paginate_workflow_app_logs_with_time_filtering(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with time-based filtering.
@ -577,8 +570,6 @@ class TestWorkflowAppService:
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -591,8 +582,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create workflow runs with different timestamps
base_time = datetime.now(UTC)
@ -627,8 +618,8 @@ class TestWorkflowAppService:
created_at=timestamp,
finished_at=timestamp + timedelta(minutes=1),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -641,8 +632,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = timestamp
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log)
@ -682,7 +673,7 @@ class TestWorkflowAppService:
assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago
def test_get_paginate_workflow_app_logs_with_pagination(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with different page sizes and limits.
@ -691,8 +682,6 @@ class TestWorkflowAppService:
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -705,8 +694,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create 25 workflow runs and logs
total_logs = 25
@ -734,8 +723,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -748,8 +737,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log)
@ -798,7 +787,7 @@ class TestWorkflowAppService:
assert len(result_large_limit["data"]) == total_logs
def test_get_paginate_workflow_app_logs_with_user_role_filtering(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with user role and session filtering.
@ -807,8 +796,6 @@ class TestWorkflowAppService:
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -821,8 +808,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create end user
end_user = EndUser(
@ -835,8 +822,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
# Create workflow runs and logs for both account and end user
workflow_runs = []
@ -864,8 +851,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 1),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -878,8 +865,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log)
@ -906,8 +893,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC) + timedelta(minutes=i + 10),
finished_at=datetime.now(UTC) + timedelta(minutes=i + 11),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
workflow_app_log = WorkflowAppLog(
tenant_id=app.tenant_id,
@ -920,8 +907,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
workflow_runs.append(workflow_run)
workflow_app_logs.append(workflow_app_log)
@ -994,7 +981,7 @@ class TestWorkflowAppService:
assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_uuid_keyword_search(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with UUID keyword search functionality.
@ -1003,8 +990,6 @@ class TestWorkflowAppService:
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -1017,8 +1002,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create workflow run with specific UUID
workflow_run_id = str(uuid.uuid4())
@ -1042,8 +1027,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC),
finished_at=datetime.now(UTC) + timedelta(minutes=1),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
# Create workflow app log
workflow_app_log = WorkflowAppLog(
@ -1057,8 +1042,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
# Act & Assert: Test UUID keyword search
service = WorkflowAppService()
@ -1085,7 +1070,7 @@ class TestWorkflowAppService:
assert result_invalid_uuid["total"] == 0
def test_get_paginate_workflow_app_logs_with_edge_cases(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with edge cases and boundary conditions.
@ -1094,8 +1079,6 @@ class TestWorkflowAppService:
fake = Faker()
app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
from extensions.ext_database import db
# Create workflow
workflow = Workflow(
id=str(uuid.uuid4()),
@ -1108,8 +1091,8 @@ class TestWorkflowAppService:
created_by=account.id,
updated_by=account.id,
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create workflow run with edge case data
workflow_run = WorkflowRun(
@ -1132,8 +1115,8 @@ class TestWorkflowAppService:
created_at=datetime.now(UTC),
finished_at=datetime.now(UTC),
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
# Create workflow app log
workflow_app_log = WorkflowAppLog(
@ -1147,8 +1130,8 @@ class TestWorkflowAppService:
)
workflow_app_log.id = str(uuid.uuid4())
workflow_app_log.created_at = datetime.now(UTC)
db.session.add(workflow_app_log)
db.session.commit()
db_session_with_containers.add(workflow_app_log)
db_session_with_containers.commit()
# Act & Assert: Test edge cases
service = WorkflowAppService()
@ -1185,7 +1168,7 @@ class TestWorkflowAppService:
assert result_high_page["has_more"] is False
def test_get_paginate_workflow_app_logs_with_empty_results(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with empty results and no data scenarios.
@ -1252,7 +1235,7 @@ class TestWorkflowAppService:
assert "Account not found" in str(exc_info.value)
def test_get_paginate_workflow_app_logs_with_complex_query_combinations(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with complex query combinations.
@ -1352,7 +1335,7 @@ class TestWorkflowAppService:
assert len(result_time_status_limit["data"]) <= 2
def test_get_paginate_workflow_app_logs_with_large_dataset_performance(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with large dataset for performance validation.
@ -1444,7 +1427,7 @@ class TestWorkflowAppService:
assert result_last_page["page"] == 3
def test_get_paginate_workflow_app_logs_with_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow app logs pagination with proper tenant isolation.

View File

@ -1,5 +1,6 @@
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from dify_graph.variables.segments import StringSegment
@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService:
# WorkflowDraftVariableService doesn't have external dependencies that need mocking
return {}
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None):
def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None):
"""
Helper method to create a test app with realistic data for testing.
@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService:
app.created_by = fake.uuid4()
app.updated_by = app.created_by
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def _create_test_workflow(self, db_session_with_containers, app, fake=None):
def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None):
"""
Helper method to create a test workflow associated with an app.
@ -110,15 +109,14 @@ class TestWorkflowDraftVariableService:
conversation_variables=[],
rag_pipeline_variables=[],
)
from extensions.ext_database import db
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
return workflow
def _create_test_variable(
self,
db_session_with_containers,
db_session_with_containers: Session,
app_id,
node_id,
name,
@ -174,13 +172,12 @@ class TestWorkflowDraftVariableService:
visible=True,
editable=True,
)
from extensions.ext_database import db
db.session.add(variable)
db.session.commit()
db_session_with_containers.add(variable)
db_session_with_containers.commit()
return variable
def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test getting a single variable by ID successfully.
@ -202,7 +199,7 @@ class TestWorkflowDraftVariableService:
assert retrieved_variable.app_id == app.id
assert retrieved_variable.get_value().value == test_value.value
def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test getting a variable that doesn't exist.
@ -217,7 +214,7 @@ class TestWorkflowDraftVariableService:
assert retrieved_variable is None
def test_get_draft_variables_by_selectors_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting variables by selectors successfully.
@ -268,7 +265,7 @@ class TestWorkflowDraftVariableService:
assert var.get_value().value == var3_value.value
def test_list_variables_without_values_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test listing variables without values successfully with pagination.
@ -300,7 +297,7 @@ class TestWorkflowDraftVariableService:
assert var.name is not None
assert var.app_id == app.id
def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test listing variables for a specific node successfully.
@ -352,7 +349,9 @@ class TestWorkflowDraftVariableService:
assert "var2" in var_names
assert "var3" not in var_names
def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_list_conversation_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test listing conversation variables successfully.
@ -393,7 +392,7 @@ class TestWorkflowDraftVariableService:
assert "conv_var2" in var_names
assert "sys_var" not in var_names
def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test updating a variable's name and value successfully.
@ -418,14 +417,15 @@ class TestWorkflowDraftVariableService:
assert updated_variable.name == "new_name"
assert updated_variable.get_value().value == new_value.value
assert updated_variable.last_edited_at is not None
from extensions.ext_database import db
db.session.refresh(variable)
db_session_with_containers.refresh(variable)
assert variable.name == "new_name"
assert variable.get_value().value == new_value.value
assert variable.last_edited_at is not None
def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_variable_not_editable(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that updating a non-editable variable raises an exception.
@ -445,17 +445,18 @@ class TestWorkflowDraftVariableService:
node_execution_id=fake.uuid4(),
editable=False, # Set as non-editable
)
from extensions.ext_database import db
db.session.add(variable)
db.session.commit()
db_session_with_containers.add(variable)
db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
with pytest.raises(UpdateNotSupportedError) as exc_info:
service.update_variable(variable, name="new_name", value=new_value)
assert "variable not support updating" in str(exc_info.value)
assert variable.id in str(exc_info.value)
def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_reset_conversation_variable_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test resetting conversation variable successfully.
@ -476,9 +477,8 @@ class TestWorkflowDraftVariableService:
selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"],
)
workflow.conversation_variables = [conv_var]
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
modified_value = StringSegment(value=fake.word())
variable = self._create_test_variable(
db_session_with_containers,
@ -489,17 +489,17 @@ class TestWorkflowDraftVariableService:
fake=fake,
)
variable.last_edited_at = fake.date_time()
db.session.commit()
db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
reset_variable = service.reset_variable(workflow, variable)
assert reset_variable is not None
assert reset_variable.get_value().value == "default_value"
assert reset_variable.last_edited_at is None
db.session.refresh(variable)
db_session_with_containers.refresh(variable)
assert variable.get_value().value == "default_value"
assert variable.last_edited_at is None
def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test deleting a single variable successfully.
@ -513,14 +513,15 @@ class TestWorkflowDraftVariableService:
variable = self._create_test_variable(
db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake
)
from extensions.ext_database import db
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_variable(variable)
assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None
def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_workflow_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test deleting all variables for a workflow successfully.
@ -550,20 +551,25 @@ class TestWorkflowDraftVariableService:
other_value,
fake=fake,
)
from extensions.ext_database import db
app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
)
assert len(app_variables) == 3
assert len(other_app_variables) == 1
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_workflow_variables(app.id)
app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all()
other_app_variables_after = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all()
)
assert len(app_variables_after) == 0
assert len(other_app_variables_after) == 1
def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_node_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test deleting all variables for a specific node successfully.
@ -605,14 +611,15 @@ class TestWorkflowDraftVariableService:
conv_value,
fake=fake,
)
from extensions.ext_database import db
target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
target_node_variables = (
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
)
other_node_variables = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
)
conv_variables = (
db.session.query(WorkflowDraftVariable)
db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all()
)
@ -622,13 +629,13 @@ class TestWorkflowDraftVariableService:
service = WorkflowDraftVariableService(db_session_with_containers)
service.delete_node_variables(app.id, node_id)
target_node_variables_after = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all()
)
other_node_variables_after = (
db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all()
)
conv_variables_after = (
db.session.query(WorkflowDraftVariable)
db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all()
)
@ -637,7 +644,7 @@ class TestWorkflowDraftVariableService:
assert len(conv_variables_after) == 1
def test_prefill_conversation_variable_default_values_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test prefill conversation variable default values successfully.
@ -665,13 +672,12 @@ class TestWorkflowDraftVariableService:
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"],
)
workflow.conversation_variables = [conv_var1, conv_var2]
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
service = WorkflowDraftVariableService(db_session_with_containers)
service.prefill_conversation_variable_default_values(workflow)
draft_variables = (
db.session.query(WorkflowDraftVariable)
db_session_with_containers.query(WorkflowDraftVariable)
.filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID)
.all()
)
@ -686,7 +692,7 @@ class TestWorkflowDraftVariableService:
assert var.get_variable_type() == DraftVariableType.CONVERSATION
def test_get_conversation_id_from_draft_variable_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting conversation ID from draft variable successfully.
@ -713,7 +719,7 @@ class TestWorkflowDraftVariableService:
assert retrieved_conv_id == conversation_id
def test_get_conversation_id_from_draft_variable_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting conversation ID when it doesn't exist.
@ -728,7 +734,9 @@ class TestWorkflowDraftVariableService:
retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id)
assert retrieved_conv_id is None
def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_list_system_variables_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test listing system variables successfully.
@ -775,7 +783,9 @@ class TestWorkflowDraftVariableService:
assert "sys_var2" in var_names
assert "conv_var" not in var_names
def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_variable_by_name_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting variables by name successfully for different types.
@ -822,7 +832,9 @@ class TestWorkflowDraftVariableService:
assert retrieved_node_var.name == "test_node_var"
assert retrieved_node_var.node_id == "test_node"
def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_variable_by_name_not_found(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting variables by name when they don't exist.

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import CreatorUserRole
from models.model import (
@ -48,7 +49,7 @@ class TestWorkflowRunService:
"account_feature_service": mock_account_feature_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -94,7 +95,7 @@ class TestWorkflowRunService:
return app, account
def _create_test_workflow_run(
self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0
self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0
):
"""
Helper method to create a test workflow run for testing.
@ -110,8 +111,6 @@ class TestWorkflowRunService:
"""
fake = Faker()
from extensions.ext_database import db
# Create workflow run with offset timestamp
base_time = datetime.now(UTC)
created_time = base_time - timedelta(minutes=offset_minutes)
@ -136,12 +135,12 @@ class TestWorkflowRunService:
finished_at=created_time,
)
db.session.add(workflow_run)
db.session.commit()
db_session_with_containers.add(workflow_run)
db_session_with_containers.commit()
return workflow_run
def _create_test_message(self, db_session_with_containers, app, account, workflow_run):
def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run):
"""
Helper method to create a test message for testing.
@ -156,8 +155,6 @@ class TestWorkflowRunService:
"""
fake = Faker()
from extensions.ext_database import db
# Create conversation first (required for message)
from models.model import Conversation
@ -170,8 +167,8 @@ class TestWorkflowRunService:
from_source=CreatorUserRole.ACCOUNT,
from_account_id=account.id,
)
db.session.add(conversation)
db.session.commit()
db_session_with_containers.add(conversation)
db_session_with_containers.commit()
# Create message
message = Message()
@ -193,12 +190,14 @@ class TestWorkflowRunService:
message.workflow_run_id = workflow_run.id
message.inputs = {"input": "test input"}
db.session.add(message)
db.session.commit()
db_session_with_containers.add(message)
db_session_with_containers.commit()
return message
def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_paginate_workflow_runs_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination of workflow runs with debugging trigger.
@ -239,7 +238,7 @@ class TestWorkflowRunService:
assert workflow_run.tenant_id == app.tenant_id
def test_get_paginate_workflow_runs_with_last_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination of workflow runs with last_id parameter.
@ -282,7 +281,7 @@ class TestWorkflowRunService:
assert workflow_run.tenant_id == app.tenant_id
def test_get_paginate_workflow_runs_default_limit(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test pagination of workflow runs with default limit.
@ -320,7 +319,7 @@ class TestWorkflowRunService:
assert workflow_run_result.tenant_id == app.tenant_id
def test_get_paginate_advanced_chat_workflow_runs_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful pagination of advanced chat workflow runs with message information.
@ -365,7 +364,7 @@ class TestWorkflowRunService:
assert workflow_run.app_id == app.id
assert workflow_run.tenant_id == app.tenant_id
def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of workflow run by ID.
@ -395,7 +394,7 @@ class TestWorkflowRunService:
assert result.type == "chat"
assert result.version == "1.0.0"
def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test workflow run retrieval when run ID does not exist.
@ -419,7 +418,7 @@ class TestWorkflowRunService:
assert result is None
def test_get_workflow_run_node_executions_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of workflow run node executions.
@ -438,7 +437,6 @@ class TestWorkflowRunService:
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Create node executions
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionModel
node_executions = []
@ -462,7 +460,7 @@ class TestWorkflowRunService:
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(node_execution)
db_session_with_containers.add(node_execution)
node_executions.append(node_execution)
paused_node_execution = WorkflowNodeExecutionModel(
@ -484,9 +482,9 @@ class TestWorkflowRunService:
created_by=account.id,
created_at=datetime.now(UTC),
)
db.session.add(paused_node_execution)
db_session_with_containers.add(paused_node_execution)
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()
@ -509,7 +507,7 @@ class TestWorkflowRunService:
assert node_execution.node_id.startswith("node_")
def test_get_workflow_run_node_executions_empty(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting node executions for a workflow run with no executions.
@ -560,7 +558,7 @@ class TestWorkflowRunService:
assert len(result) == 0
def test_get_workflow_run_node_executions_invalid_workflow_run_id(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting node executions with invalid workflow run ID.
@ -611,7 +609,7 @@ class TestWorkflowRunService:
assert len(result) == 0
def test_get_workflow_run_node_executions_database_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test getting node executions when database encounters an error.
@ -662,7 +660,7 @@ class TestWorkflowRunService:
)
def test_get_workflow_run_node_executions_end_user(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test node execution retrieval for end user.
@ -680,7 +678,6 @@ class TestWorkflowRunService:
workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging")
# Create end user
from extensions.ext_database import db
from models.model import EndUser
end_user = EndUser(
@ -692,8 +689,8 @@ class TestWorkflowRunService:
external_user_id=str(uuid.uuid4()),
name=fake.name(),
)
db.session.add(end_user)
db.session.commit()
db_session_with_containers.add(end_user)
db_session_with_containers.commit()
# Create node execution
from models.workflow import WorkflowNodeExecutionModel
@ -717,8 +714,8 @@ class TestWorkflowRunService:
created_by=end_user.id,
created_at=datetime.now(UTC),
)
db.session.add(node_execution)
db.session.commit()
db_session_with_containers.add(node_execution)
db_session_with_containers.commit()
# Act: Execute the method under test
workflow_run_service = WorkflowRunService()

View File

@ -10,6 +10,7 @@ from unittest.mock import MagicMock
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import Account, App, Workflow
from models.model import AppMode
@ -32,7 +33,7 @@ class TestWorkflowService:
and realistic testing environment with actual database interactions.
"""
def _create_test_account(self, db_session_with_containers, fake=None):
def _create_test_account(self, db_session_with_containers: Session, fake=None):
"""
Helper method to create a test account with realistic data.
@ -67,18 +68,16 @@ class TestWorkflowService:
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
from extensions.ext_database import db
db.session.add(tenant)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Set the current tenant for the account
account.current_tenant = tenant
return account
def _create_test_app(self, db_session_with_containers, fake=None):
def _create_test_app(self, db_session_with_containers: Session, fake=None):
"""
Helper method to create a test app with realistic data.
@ -106,13 +105,11 @@ class TestWorkflowService:
)
app.updated_by = app.created_by
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
return app
def _create_test_workflow(self, db_session_with_containers, app, account, fake=None):
def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None):
"""
Helper method to create a test workflow associated with an app.
@ -141,13 +138,11 @@ class TestWorkflowService:
conversation_variables=[],
)
from extensions.ext_database import db
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
return workflow
def test_get_node_last_run_success(self, db_session_with_containers):
def test_get_node_last_run_success(self, db_session_with_containers: Session):
"""
Test successful retrieval of the most recent execution for a specific node.
@ -180,10 +175,8 @@ class TestWorkflowService:
node_execution.created_by = account.id # Required field
node_execution.created_at = fake.date_time_this_year()
from extensions.ext_database import db
db.session.add(node_execution)
db.session.commit()
db_session_with_containers.add(node_execution)
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -196,7 +189,7 @@ class TestWorkflowService:
assert result.workflow_id == workflow.id
assert result.status == "succeeded"
def test_get_node_last_run_not_found(self, db_session_with_containers):
def test_get_node_last_run_not_found(self, db_session_with_containers: Session):
"""
Test retrieval when no execution record exists for the specified node.
@ -217,7 +210,7 @@ class TestWorkflowService:
# Assert
assert result is None
def test_is_workflow_exist_true(self, db_session_with_containers):
def test_is_workflow_exist_true(self, db_session_with_containers: Session):
"""
Test workflow existence check when a draft workflow exists.
@ -238,7 +231,7 @@ class TestWorkflowService:
# Assert
assert result is True
def test_is_workflow_exist_false(self, db_session_with_containers):
def test_is_workflow_exist_false(self, db_session_with_containers: Session):
"""
Test workflow existence check when no draft workflow exists.
@ -258,7 +251,7 @@ class TestWorkflowService:
# Assert
assert result is False
def test_get_draft_workflow_success(self, db_session_with_containers):
def test_get_draft_workflow_success(self, db_session_with_containers: Session):
"""
Test successful retrieval of a draft workflow.
@ -284,7 +277,7 @@ class TestWorkflowService:
assert result.app_id == app.id
assert result.tenant_id == app.tenant_id
def test_get_draft_workflow_not_found(self, db_session_with_containers):
def test_get_draft_workflow_not_found(self, db_session_with_containers: Session):
"""
Test draft workflow retrieval when no draft workflow exists.
@ -304,7 +297,7 @@ class TestWorkflowService:
# Assert
assert result is None
def test_get_published_workflow_by_id_success(self, db_session_with_containers):
def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session):
"""
Test successful retrieval of a published workflow by ID.
@ -321,9 +314,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Published version
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -336,7 +327,7 @@ class TestWorkflowService:
assert result.version != Workflow.VERSION_DRAFT
assert result.app_id == app.id
def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers):
def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session):
"""
Test error when trying to retrieve a draft workflow as published.
@ -359,7 +350,7 @@ class TestWorkflowService:
with pytest.raises(IsDraftWorkflowError):
workflow_service.get_published_workflow_by_id(app, workflow.id)
def test_get_published_workflow_by_id_not_found(self, db_session_with_containers):
def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session):
"""
Test retrieval when no workflow exists with the specified ID.
@ -379,7 +370,7 @@ class TestWorkflowService:
# Assert
assert result is None
def test_get_published_workflow_success(self, db_session_with_containers):
def test_get_published_workflow_success(self, db_session_with_containers: Session):
"""
Test successful retrieval of the current published workflow for an app.
@ -395,10 +386,8 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Published version
from extensions.ext_database import db
app.workflow_id = workflow.id
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -411,7 +400,7 @@ class TestWorkflowService:
assert result.version != Workflow.VERSION_DRAFT
assert result.app_id == app.id
def test_get_published_workflow_no_workflow_id(self, db_session_with_containers):
def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session):
"""
Test retrieval when app has no associated workflow ID.
@ -431,7 +420,7 @@ class TestWorkflowService:
# Assert
assert result is None
def test_get_all_published_workflow_pagination(self, db_session_with_containers):
def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session):
"""
Test pagination of published workflows.
@ -455,15 +444,13 @@ class TestWorkflowService:
# Set the app's workflow_id to the first workflow
app.workflow_id = workflows[0].id
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
# Act - First page
result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session,
session=db_session_with_containers,
app_model=app,
page=1,
limit=3,
@ -476,7 +463,7 @@ class TestWorkflowService:
# Act - Second page
result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session,
session=db_session_with_containers,
app_model=app,
page=2,
limit=3,
@ -487,7 +474,7 @@ class TestWorkflowService:
assert len(result_workflows) == 2
assert has_more is False
def test_get_all_published_workflow_user_filter(self, db_session_with_containers):
def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session):
"""
Test filtering published workflows by user.
@ -513,22 +500,20 @@ class TestWorkflowService:
# Set the app's workflow_id to the first workflow
app.workflow_id = workflow1.id
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
# Act - Filter by account1
result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session, app_model=app, page=1, limit=10, user_id=account1.id
session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id
)
# Assert
assert len(result_workflows) == 1
assert result_workflows[0].created_by == account1.id
def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers):
def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session):
"""
Test filtering published workflows to show only named workflows.
@ -557,22 +542,20 @@ class TestWorkflowService:
# Set the app's workflow_id to the first workflow
app.workflow_id = workflow1.id
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
# Act - Filter named only
result_workflows, has_more = workflow_service.get_all_published_workflow(
session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True
session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True
)
# Assert
assert len(result_workflows) == 2
assert all(wf.marked_name for wf in result_workflows)
def test_sync_draft_workflow_create_new(self, db_session_with_containers):
def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session):
"""
Test creating a new draft workflow through sync operation.
@ -624,7 +607,7 @@ class TestWorkflowService:
assert result.features == json.dumps(features)
assert result.created_by == account.id
def test_sync_draft_workflow_update_existing(self, db_session_with_containers):
def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session):
"""
Test updating an existing draft workflow through sync operation.
@ -688,7 +671,7 @@ class TestWorkflowService:
assert result.features == json.dumps(new_features)
assert result.updated_by == account.id
def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers):
def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session):
"""
Test error when sync is attempted with mismatched hash.
@ -738,7 +721,7 @@ class TestWorkflowService:
conversation_variables=conversation_variables,
)
def test_publish_workflow_success(self, db_session_with_containers):
def test_publish_workflow_success(self, db_session_with_containers: Session):
"""
Test successful workflow publishing.
@ -755,9 +738,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = Workflow.VERSION_DRAFT
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -777,7 +758,7 @@ class TestWorkflowService:
assert len(result.version) > 10 # Should be a reasonable timestamp length
assert result.created_by == account.id
def test_publish_workflow_no_draft_error(self, db_session_with_containers):
def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session):
"""
Test error when publishing workflow without draft.
@ -797,7 +778,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="No valid workflow found"):
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
def test_publish_workflow_already_published_error(self, db_session_with_containers):
def test_publish_workflow_already_published_error(self, db_session_with_containers: Session):
"""
Test error when publishing already published workflow.
@ -813,9 +794,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Already published
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -823,7 +802,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="No valid workflow found"):
workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account)
def test_get_default_block_configs(self, db_session_with_containers):
def test_get_default_block_configs(self, db_session_with_containers: Session):
"""
Test retrieval of default block configurations for all node types.
@ -847,7 +826,7 @@ class TestWorkflowService:
assert isinstance(config, dict)
# The structure can vary, so we just check it's a dict
def test_get_default_block_config_specific_type(self, db_session_with_containers):
def test_get_default_block_config_specific_type(self, db_session_with_containers: Session):
"""
Test retrieval of default block configuration for a specific node type.
@ -867,7 +846,7 @@ class TestWorkflowService:
# This is acceptable behavior
assert result is None or isinstance(result, dict)
def test_get_default_block_config_invalid_type(self, db_session_with_containers):
def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session):
"""
Test retrieval of default block configuration for invalid node type.
@ -887,7 +866,7 @@ class TestWorkflowService:
# It's also acceptable for the service to raise a ValueError for invalid types
pass
def test_get_default_block_config_with_filters(self, db_session_with_containers):
def test_get_default_block_config_with_filters(self, db_session_with_containers: Session):
"""
Test retrieval of default block configuration with filters.
@ -907,7 +886,7 @@ class TestWorkflowService:
# Result might be None if filters don't match, but should not raise error
assert result is None or isinstance(result, dict)
def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers):
def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session):
"""
Test successful conversion from chat mode app to workflow mode.
@ -944,11 +923,9 @@ class TestWorkflowService:
)
app_model_config.id = fake.uuid4()
from extensions.ext_database import db
db.session.add(app_model_config)
db_session_with_containers.add(app_model_config)
app.app_model_config_id = app_model_config.id
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
conversion_args = {
@ -969,7 +946,7 @@ class TestWorkflowService:
assert result.icon_type == conversion_args["icon_type"]
assert result.icon_background == conversion_args["icon_background"]
def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers):
def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session):
"""
Test successful conversion from completion mode app to workflow mode.
@ -1006,11 +983,9 @@ class TestWorkflowService:
)
app_model_config.id = fake.uuid4()
from extensions.ext_database import db
db.session.add(app_model_config)
db_session_with_containers.add(app_model_config)
app.app_model_config_id = app_model_config.id
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
conversion_args = {
@ -1031,7 +1006,7 @@ class TestWorkflowService:
assert result.icon_type == conversion_args["icon_type"]
assert result.icon_background == conversion_args["icon_background"]
def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers):
def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session):
"""
Test error when attempting to convert unsupported app mode.
@ -1046,9 +1021,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
conversion_args = {"name": "Test"}
@ -1057,7 +1030,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"):
workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args)
def test_validate_features_structure_advanced_chat(self, db_session_with_containers):
def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session):
"""
Test feature structure validation for advanced chat mode apps.
@ -1069,9 +1042,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.ADVANCED_CHAT
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
features = {
@ -1088,7 +1059,7 @@ class TestWorkflowService:
# The exact behavior depends on the AdvancedChatAppConfigManager implementation
assert result is not None or isinstance(result, dict)
def test_validate_features_structure_workflow(self, db_session_with_containers):
def test_validate_features_structure_workflow(self, db_session_with_containers: Session):
"""
Test feature structure validation for workflow mode apps.
@ -1100,9 +1071,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake)
app.mode = AppMode.WORKFLOW
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
features = {"workflow_config": {"max_steps": 10, "timeout": 300}}
@ -1115,7 +1084,7 @@ class TestWorkflowService:
# The exact behavior depends on the WorkflowAppConfigManager implementation
assert result is not None or isinstance(result, dict)
def test_validate_features_structure_invalid_mode(self, db_session_with_containers):
def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session):
"""
Test error when validating features for invalid app mode.
@ -1127,9 +1096,7 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake)
app.mode = "invalid_mode" # Invalid mode
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
features = {"test": "value"}
@ -1138,7 +1105,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"):
workflow_service.validate_features_structure(app_model=app, features=features)
def test_update_workflow_success(self, db_session_with_containers):
def test_update_workflow_success(self, db_session_with_containers: Session):
"""
Test successful workflow update with allowed fields.
@ -1152,16 +1119,14 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake)
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"}
# Act
result = workflow_service.update_workflow(
session=db.session,
session=db_session_with_containers,
workflow_id=workflow.id,
tenant_id=workflow.tenant_id,
account_id=account.id,
@ -1174,7 +1139,7 @@ class TestWorkflowService:
assert result.marked_comment == update_data["marked_comment"]
assert result.updated_by == account.id
def test_update_workflow_not_found(self, db_session_with_containers):
def test_update_workflow_not_found(self, db_session_with_containers: Session):
"""
Test workflow update when workflow doesn't exist.
@ -1186,15 +1151,13 @@ class TestWorkflowService:
account = self._create_test_account(db_session_with_containers, fake)
app = self._create_test_app(db_session_with_containers, fake)
from extensions.ext_database import db
workflow_service = WorkflowService()
non_existent_workflow_id = fake.uuid4()
update_data = {"marked_name": "Test"}
# Act
result = workflow_service.update_workflow(
session=db.session,
session=db_session_with_containers,
workflow_id=non_existent_workflow_id,
tenant_id=app.tenant_id,
account_id=account.id,
@ -1204,7 +1167,7 @@ class TestWorkflowService:
# Assert
assert result is None
def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers):
def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session):
"""
Test that workflow update ignores disallowed fields.
@ -1218,9 +1181,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
original_name = workflow.marked_name
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
update_data = {
@ -1231,7 +1192,7 @@ class TestWorkflowService:
# Act
result = workflow_service.update_workflow(
session=db.session,
session=db_session_with_containers,
workflow_id=workflow.id,
tenant_id=workflow.tenant_id,
account_id=account.id,
@ -1245,7 +1206,7 @@ class TestWorkflowService:
assert result.graph == workflow.graph
assert result.features == workflow.features
def test_delete_workflow_success(self, db_session_with_containers):
def test_delete_workflow_success(self, db_session_with_containers: Session):
"""
Test successful workflow deletion.
@ -1262,25 +1223,23 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
workflow.version = "2024.01.01.001" # Published version
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
# Act
result = workflow_service.delete_workflow(
session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
)
# Assert
assert result is True
# Verify workflow is actually deleted
deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first()
deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first()
assert deleted_workflow is None
def test_delete_workflow_draft_error(self, db_session_with_containers):
def test_delete_workflow_draft_error(self, db_session_with_containers: Session):
"""
Test error when attempting to delete a draft workflow.
@ -1296,9 +1255,7 @@ class TestWorkflowService:
workflow = self._create_test_workflow(db_session_with_containers, app, account, fake)
# Keep as draft version
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -1306,9 +1263,11 @@ class TestWorkflowService:
from services.errors.workflow_service import DraftWorkflowDeletionError
with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"):
workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id)
workflow_service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
)
def test_delete_workflow_in_use_error(self, db_session_with_containers):
def test_delete_workflow_in_use_error(self, db_session_with_containers: Session):
"""
Test error when attempting to delete a workflow that's in use by an app.
@ -1327,9 +1286,7 @@ class TestWorkflowService:
# Associate workflow with app
app.workflow_id = workflow.id
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
workflow_service = WorkflowService()
@ -1337,9 +1294,11 @@ class TestWorkflowService:
from services.errors.workflow_service import WorkflowInUseError
with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"):
workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id)
workflow_service.delete_workflow(
session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id
)
def test_delete_workflow_not_found_error(self, db_session_with_containers):
def test_delete_workflow_not_found_error(self, db_session_with_containers: Session):
"""
Test error when attempting to delete a non-existent workflow.
@ -1351,17 +1310,15 @@ class TestWorkflowService:
app = self._create_test_app(db_session_with_containers, fake)
non_existent_workflow_id = fake.uuid4()
from extensions.ext_database import db
workflow_service = WorkflowService()
# Act & Assert
with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"):
workflow_service.delete_workflow(
session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id
session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id
)
def test_run_free_workflow_node_success(self, db_session_with_containers):
def test_run_free_workflow_node_success(self, db_session_with_containers: Session):
"""
Test successful execution of a free workflow node.
@ -1413,7 +1370,7 @@ class TestWorkflowService:
assert result.workflow_id == "" # No workflow ID for free nodes
assert result.index == 1
def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers):
def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session):
"""
Test execution of a free workflow node with complex input data.
@ -1454,7 +1411,7 @@ class TestWorkflowService:
error_msg = str(exc_info.value).lower()
assert any(keyword in error_msg for keyword in ["start", "not supported", "external"])
def test_handle_node_run_result_success(self, db_session_with_containers):
def test_handle_node_run_result_success(self, db_session_with_containers: Session):
"""
Test successful handling of node run results.
@ -1529,7 +1486,7 @@ class TestWorkflowService:
assert result.outputs is not None
assert result.process_data is not None
def test_handle_node_run_result_failure(self, db_session_with_containers):
def test_handle_node_run_result_failure(self, db_session_with_containers: Session):
"""
Test handling of failed node run results.
@ -1598,7 +1555,7 @@ class TestWorkflowService:
assert result.error is not None
assert "Test error message" in str(result.error)
def test_handle_node_run_result_continue_on_error(self, db_session_with_containers):
def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session):
"""
Test handling of node run results with continue_on_error strategy.

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from services.workspace_service import WorkspaceService
@ -29,7 +30,7 @@ class TestWorkspaceService:
"dify_config": mock_dify_config,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -50,10 +51,8 @@ class TestWorkspaceService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -62,8 +61,8 @@ class TestWorkspaceService:
plan="basic",
custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}',
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
@ -72,15 +71,15 @@ class TestWorkspaceService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of tenant information with all features enabled.
@ -121,13 +120,12 @@ class TestWorkspaceService:
assert "replace_webapp_logo" in result["custom_config"]
# Verify database state
from extensions.ext_database import db
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_without_custom_config(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval when custom config features are disabled.
@ -167,13 +165,12 @@ class TestWorkspaceService:
assert "custom_config" not in result
# Verify database state
from extensions.ext_database import db
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_normal_user_role(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval for normal user role without privileged features.
@ -191,11 +188,14 @@ class TestWorkspaceService:
)
# Update the join to have normal role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.NORMAL
db.session.commit()
db_session_with_containers.commit()
# Setup mocks for feature service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -220,11 +220,11 @@ class TestWorkspaceService:
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_admin_role_and_logo_replacement(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval for admin role with logo replacement enabled.
@ -242,11 +242,14 @@ class TestWorkspaceService:
)
# Update the join to have admin role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.ADMIN
db.session.commit()
db_session_with_containers.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -268,10 +271,12 @@ class TestWorkspaceService:
assert "replace_webapp_logo" in result["custom_config"]
# Verify database state
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_tenant_info_with_tenant_none(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval when tenant parameter is None.
@ -290,7 +295,7 @@ class TestWorkspaceService:
assert result is None
def test_get_tenant_info_with_custom_config_variations(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval with various custom config configurations.
@ -323,10 +328,8 @@ class TestWorkspaceService:
# Update tenant custom config
import json
from extensions.ext_database import db
tenant.custom_config = json.dumps(config)
db.session.commit()
db_session_with_containers.commit()
# Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -353,11 +356,11 @@ class TestWorkspaceService:
assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"]
# Verify database state
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_editor_role_and_limited_permissions(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval for editor role with limited permissions.
@ -375,11 +378,14 @@ class TestWorkspaceService:
)
# Update the join to have editor role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.EDITOR
db.session.commit()
db_session_with_containers.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -400,11 +406,11 @@ class TestWorkspaceService:
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_dataset_operator_role(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval for dataset operator role.
@ -422,11 +428,14 @@ class TestWorkspaceService:
)
# Update the join to have dataset operator role
from extensions.ext_database import db
join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
join = (
db_session_with_containers.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
join.role = TenantAccountRole.DATASET_OPERATOR
db.session.commit()
db_session_with_containers.commit()
# Setup mocks for feature service and tenant service
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -447,11 +456,11 @@ class TestWorkspaceService:
assert "custom_config" not in result
# Verify database state
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_with_complex_custom_config_scenarios(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant info retrieval with complex custom config scenarios.
@ -491,10 +500,8 @@ class TestWorkspaceService:
# Update tenant custom config
import json
from extensions.ext_database import db
tenant.custom_config = json.dumps(config)
db.session.commit()
db_session_with_containers.commit()
# Setup mocks
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
@ -525,5 +532,5 @@ class TestWorkspaceService:
assert result["custom_config"]["remove_webapp_brand"] is False
# Verify database state
db.session.refresh(tenant)
db_session_with_containers.refresh(tenant)
assert tenant.id is not None

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import TypeAdapter, ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ApiProviderSchemaType
from models import Account, Tenant
@ -34,7 +35,7 @@ class TestApiToolManageService:
"provider_controller": mock_provider_controller,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -55,18 +56,16 @@ class TestApiToolManageService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
@ -77,8 +76,8 @@ class TestApiToolManageService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -118,7 +117,7 @@ class TestApiToolManageService:
"""
def test_parser_api_schema_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful parsing of API schema.
@ -163,7 +162,7 @@ class TestApiToolManageService:
assert api_key_value_field["default"] == ""
def test_parser_api_schema_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test parsing of invalid API schema.
@ -183,7 +182,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_parser_api_schema_malformed_json(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test parsing of malformed JSON schema.
@ -203,7 +202,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_convert_schema_to_tool_bundles_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of schema to tool bundles.
@ -233,7 +232,7 @@ class TestApiToolManageService:
assert tool_bundle.operation_id == "testOperation"
def test_convert_schema_to_tool_bundles_with_extra_info(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of schema to tool bundles with extra info.
@ -259,7 +258,7 @@ class TestApiToolManageService:
assert isinstance(schema_type, str)
def test_convert_schema_to_tool_bundles_invalid_schema(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of invalid schema to tool bundles.
@ -279,7 +278,7 @@ class TestApiToolManageService:
assert "invalid schema" in str(exc_info.value)
def test_create_api_tool_provider_success(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful creation of API tool provider.
@ -324,10 +323,9 @@ class TestApiToolManageService:
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
provider = (
db.session.query(ApiToolProvider)
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)
@ -347,7 +345,7 @@ class TestApiToolManageService:
mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once()
def test_create_api_tool_provider_duplicate_name(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creation of API tool provider with duplicate name.
@ -404,7 +402,7 @@ class TestApiToolManageService:
assert f"provider {provider_name} already exists" in str(exc_info.value)
def test_create_api_tool_provider_invalid_schema_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creation of API tool provider with invalid schema type.
@ -436,7 +434,7 @@ class TestApiToolManageService:
assert "validation error" in str(exc_info.value)
def test_create_api_tool_provider_missing_auth_type(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test creation of API tool provider with missing auth type.
@ -479,7 +477,7 @@ class TestApiToolManageService:
assert "auth_type is required" in str(exc_info.value)
def test_create_api_tool_provider_with_api_key_auth(
self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies
self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful creation of API tool provider with API key authentication.
@ -522,10 +520,9 @@ class TestApiToolManageService:
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
provider = (
db.session.query(ApiToolProvider)
db_session_with_containers.query(ApiToolProvider)
.filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name)
.first()
)

View File

@ -2,6 +2,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import ToolProviderType
from models import Account, Tenant
@ -41,7 +42,7 @@ class TestMCPToolManageService:
"tool_transform_service": mock_tool_transform_service,
}
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -62,18 +63,16 @@ class TestMCPToolManageService:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
@ -84,8 +83,8 @@ class TestMCPToolManageService:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -93,7 +92,7 @@ class TestMCPToolManageService:
return account, tenant
def _create_test_mcp_provider(
self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id
):
"""
Helper method to create a test MCP tool provider for testing.
@ -124,15 +123,13 @@ class TestMCPToolManageService:
sse_read_timeout=300.0,
)
from extensions.ext_database import db
db.session.add(mcp_provider)
db.session.commit()
db_session_with_containers.add(mcp_provider)
db_session_with_containers.commit()
return mcp_provider
def test_get_mcp_provider_by_provider_id_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of MCP provider by provider ID.
@ -153,9 +150,8 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
@ -166,12 +162,12 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.server_identifier == mcp_provider.server_identifier
def test_get_mcp_provider_by_provider_id_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP provider is not found by provider ID.
@ -190,14 +186,13 @@ class TestMCPToolManageService:
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id)
def test_get_mcp_provider_by_provider_id_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when retrieving MCP provider by provider ID.
@ -223,14 +218,13 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id)
def test_get_mcp_provider_by_server_identifier_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful retrieval of MCP provider by server identifier.
@ -251,9 +245,8 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id)
# Assert: Verify the expected outcomes
@ -264,12 +257,12 @@ class TestMCPToolManageService:
assert result.user_id == account.id
# Verify database state
db.session.refresh(result)
db_session_with_containers.refresh(result)
assert result.id is not None
assert result.name == mcp_provider.name
def test_get_mcp_provider_by_server_identifier_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP provider is not found by server identifier.
@ -288,14 +281,13 @@ class TestMCPToolManageService:
non_existent_identifier = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id)
def test_get_mcp_provider_by_server_identifier_tenant_isolation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when retrieving MCP provider by server identifier.
@ -321,13 +313,12 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id)
def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful creation of MCP provider.
@ -365,9 +356,8 @@ class TestMCPToolManageService:
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
@ -389,10 +379,9 @@ class TestMCPToolManageService:
assert result.type == ToolProviderType.MCP
# Verify database state
from extensions.ext_database import db
created_provider = (
db.session.query(MCPToolProvider)
db_session_with_containers.query(MCPToolProvider)
.filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider")
.first()
)
@ -410,7 +399,9 @@ class TestMCPToolManageService:
)
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once()
def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_mcp_provider_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when creating MCP provider with duplicate name.
@ -427,9 +418,8 @@ class TestMCPToolManageService:
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider",
@ -463,7 +453,7 @@ class TestMCPToolManageService:
)
def test_create_mcp_provider_duplicate_server_url(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when creating MCP provider with duplicate server URL.
@ -481,9 +471,8 @@ class TestMCPToolManageService:
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
@ -517,7 +506,7 @@ class TestMCPToolManageService:
)
def test_create_mcp_provider_duplicate_server_identifier(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when creating MCP provider with duplicate server identifier.
@ -535,9 +524,8 @@ class TestMCPToolManageService:
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.create_provider(
tenant_id=tenant.id,
name="Test MCP Provider 1",
@ -570,7 +558,7 @@ class TestMCPToolManageService:
),
)
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful retrieval of MCP tools for a tenant.
@ -602,9 +590,7 @@ class TestMCPToolManageService:
)
provider3.name = "Gamma Provider"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Setup mock for transformation service
from core.tools.entities.api_entities import ToolProviderApiEntity
@ -647,9 +633,8 @@ class TestMCPToolManageService:
]
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.list_providers(tenant_id=tenant.id, for_list=True)
# Assert: Verify the expected outcomes
@ -666,7 +651,9 @@ class TestMCPToolManageService:
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3
)
def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies):
def test_retrieve_mcp_tools_empty_list(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test retrieval of MCP tools when tenant has no providers.
@ -684,9 +671,8 @@ class TestMCPToolManageService:
# No MCP providers created for this tenant
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.list_providers(tenant_id=tenant.id, for_list=False)
# Assert: Verify the expected outcomes
@ -697,7 +683,9 @@ class TestMCPToolManageService:
# Verify no transformation service calls for empty list
mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called()
def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
def test_retrieve_mcp_tools_tenant_isolation(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when retrieving MCP tools.
@ -756,9 +744,8 @@ class TestMCPToolManageService:
]
# Act: Execute the method under test for both tenants
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result1 = service.list_providers(tenant_id=tenant1.id, for_list=True)
result2 = service.list_providers(tenant_id=tenant2.id, for_list=True)
@ -769,7 +756,7 @@ class TestMCPToolManageService:
assert result2[0].id == provider2.id
def test_list_mcp_tool_from_remote_server_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful listing of MCP tools from remote server.
@ -797,9 +784,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True # Provider must be authenticated to list tools
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
@ -821,9 +806,8 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
@ -834,7 +818,7 @@ class TestMCPToolManageService:
# Note: server_url is mocked, so we skip that assertion to avoid encryption issues
# Verify database state was updated
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True
assert mcp_provider.tools != "[]"
assert mcp_provider.updated_at is not None
@ -844,7 +828,7 @@ class TestMCPToolManageService:
mock_mcp_client.assert_called_once()
def test_list_mcp_tool_from_remote_server_auth_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP server requires authentication.
@ -871,9 +855,7 @@ class TestMCPToolManageService:
mcp_provider.authed = False
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
@ -887,19 +869,18 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="Please auth the tool first"):
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.tools == "[]"
def test_list_mcp_tool_from_remote_server_connection_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when MCP server connection fails.
@ -926,9 +907,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True # Provider must be authenticated to test connection errors
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the decryption process at the rsa level to avoid key file issues
with patch("libs.rsa.decrypt") as mock_decrypt:
@ -942,18 +921,17 @@ class TestMCPToolManageService:
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"):
service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Verify database state was not changed
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True # Provider remains authenticated
assert mcp_provider.tools == "[]"
def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful deletion of MCP tool.
@ -974,20 +952,19 @@ class TestMCPToolManageService:
)
# Verify provider exists
from extensions.ext_database import db
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None
# Act: Execute the method under test
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id)
# Assert: Verify the expected outcomes
# Provider should be deleted from database
deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first()
assert deleted_provider is None
def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test error handling when deleting non-existent MCP tool.
@ -1005,13 +982,14 @@ class TestMCPToolManageService:
non_existent_id = str(fake.uuid4())
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id)
def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies):
def test_delete_mcp_tool_tenant_isolation(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tenant isolation when deleting MCP tool.
@ -1036,18 +1014,16 @@ class TestMCPToolManageService:
)
# Act & Assert: Verify tenant isolation
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool not found"):
service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id)
# Verify provider still exists in tenant1
from extensions.ext_database import db
assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None
def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful update of MCP provider.
@ -1070,14 +1046,12 @@ class TestMCPToolManageService:
original_name = mcp_provider.name
original_icon = mcp_provider.icon
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.update_provider(
tenant_id=tenant.id,
provider_id=mcp_provider.id,
@ -1094,7 +1068,7 @@ class TestMCPToolManageService:
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.name == "Updated MCP Provider"
assert mcp_provider.server_identifier == "updated_identifier_123"
assert mcp_provider.timeout == 45.0
@ -1108,7 +1082,9 @@ class TestMCPToolManageService:
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_mcp_provider_duplicate_name(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when updating MCP provider with duplicate name.
@ -1134,15 +1110,12 @@ class TestMCPToolManageService:
)
provider2.name = "Second Provider"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
with pytest.raises(ValueError, match="MCP tool First Provider already exists"):
service.update_provider(
tenant_id=tenant.id,
@ -1160,7 +1133,7 @@ class TestMCPToolManageService:
)
def test_update_mcp_provider_credentials_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful update of MCP provider credentials.
@ -1185,9 +1158,7 @@ class TestMCPToolManageService:
mcp_provider.authed = False
mcp_provider.tools = "[]"
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the provider controller and encryption
with (
@ -1202,9 +1173,8 @@ class TestMCPToolManageService:
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.update_provider_credentials(
provider_id=mcp_provider.id,
tenant_id=tenant.id,
@ -1213,7 +1183,7 @@ class TestMCPToolManageService:
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is True
assert mcp_provider.updated_at is not None
@ -1225,7 +1195,7 @@ class TestMCPToolManageService:
assert "new_key" in credentials
def test_update_mcp_provider_credentials_not_authed(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test update of MCP provider credentials when not authenticated.
@ -1249,9 +1219,7 @@ class TestMCPToolManageService:
mcp_provider.authed = True
mcp_provider.tools = '[{"name": "test_tool"}]'
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the provider controller and encryption
with (
@ -1266,9 +1234,8 @@ class TestMCPToolManageService:
mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"}
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
service = MCPToolManageService(db_session_with_containers)
service.update_provider_credentials(
provider_id=mcp_provider.id,
tenant_id=tenant.id,
@ -1277,12 +1244,14 @@ class TestMCPToolManageService:
)
# Assert: Verify the expected outcomes
db.session.refresh(mcp_provider)
db_session_with_containers.refresh(mcp_provider)
assert mcp_provider.authed is False
assert mcp_provider.tools == "[]"
assert mcp_provider.updated_at is not None
def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_re_connect_mcp_provider_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful reconnection to MCP provider.
@ -1343,7 +1312,9 @@ class TestMCPToolManageService:
sse_read_timeout=mcp_provider.sse_read_timeout,
)
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_re_connect_mcp_provider_auth_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test reconnection to MCP provider when authentication fails.
@ -1385,7 +1356,7 @@ class TestMCPToolManageService:
assert result.encrypted_credentials == "{}"
def test_re_connect_mcp_provider_connection_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test reconnection to MCP provider when connection fails.

View File

@ -2,6 +2,7 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
@ -27,7 +28,7 @@ class TestToolTransformService:
}
def _create_test_tool_provider(
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api"
):
"""
Helper method to create a test tool provider for testing.
@ -89,14 +90,12 @@ class TestToolTransformService:
else:
raise ValueError(f"Unknown provider type: {provider_type}")
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
return provider
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful plugin icon URL generation.
@ -126,7 +125,7 @@ class TestToolTransformService:
assert result == expected_url
def test_get_plugin_icon_url_with_empty_console_url(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test plugin icon URL generation when CONSOLE_API_URL is empty.
@ -156,7 +155,7 @@ class TestToolTransformService:
assert result == expected_url
def test_get_tool_provider_icon_url_builtin_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for builtin providers.
@ -194,7 +193,7 @@ class TestToolTransformService:
assert result == expected_encoded
def test_get_tool_provider_icon_url_api_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for API providers.
@ -220,7 +219,7 @@ class TestToolTransformService:
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_api_invalid_json(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tool provider icon URL generation for API providers with invalid JSON.
@ -246,7 +245,7 @@ class TestToolTransformService:
assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"
def test_get_tool_provider_icon_url_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for workflow providers.
@ -271,7 +270,7 @@ class TestToolTransformService:
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_mcp_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for MCP providers.
@ -296,7 +295,7 @@ class TestToolTransformService:
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_unknown_type(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test tool provider icon URL generation for unknown provider types.
@ -317,7 +316,9 @@ class TestToolTransformService:
# Assert: Verify the expected outcomes
assert result == ""
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_repack_provider_dict_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful provider repacking with dictionary input.
@ -341,7 +342,9 @@ class TestToolTransformService:
# Note: provider name may contain spaces that get URL encoded
assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_repack_provider_entity_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful provider repacking with ToolProviderApiEntity input.
@ -389,7 +392,7 @@ class TestToolTransformService:
assert "test_icon_dark.png" in provider.icon_dark
def test_repack_provider_entity_no_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
@ -435,7 +438,9 @@ class TestToolTransformService:
assert provider.icon_dark["background"] == "#252525"
assert provider.icon_dark["content"] == "🔧"
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
def test_repack_provider_entity_no_dark_icon(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test provider repacking with ToolProviderApiEntity input without dark icon.
@ -477,7 +482,7 @@ class TestToolTransformService:
assert provider.icon_dark == ""
def test_builtin_provider_to_user_provider_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of builtin provider to user provider.
@ -545,7 +550,7 @@ class TestToolTransformService:
assert result.original_credentials == {"api_key": "decrypted_key"}
def test_builtin_provider_to_user_provider_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of builtin provider to user provider with plugin.
@ -589,7 +594,7 @@ class TestToolTransformService:
assert result.allow_delete is False
def test_builtin_provider_to_user_provider_no_credentials(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of builtin provider to user provider without credentials.
@ -630,7 +635,9 @@ class TestToolTransformService:
assert result.allow_delete is False
assert result.masked_credentials == {"api_key": ""}
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_api_provider_to_controller_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of API provider to controller.
@ -655,10 +662,8 @@ class TestToolTransformService:
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
@ -669,7 +674,7 @@ class TestToolTransformService:
# Additional assertions would depend on the actual controller implementation
def test_api_provider_to_controller_api_key_query(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of API provider to controller with api_key_query auth type.
@ -693,10 +698,8 @@ class TestToolTransformService:
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
@ -706,7 +709,7 @@ class TestToolTransformService:
assert hasattr(result, "from_db")
def test_api_provider_to_controller_backward_compatibility(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test conversion of API provider to controller with backward compatibility auth types.
@ -731,10 +734,8 @@ class TestToolTransformService:
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
@ -744,7 +745,7 @@ class TestToolTransformService:
assert hasattr(result, "from_db")
def test_workflow_provider_to_controller_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of workflow provider to controller.
@ -769,10 +770,8 @@ class TestToolTransformService:
parameter_configuration="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
db_session_with_containers.add(provider)
db_session_with_containers.commit()
# Mock the WorkflowToolProviderController.from_db method to avoid app dependency
with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:

View File

@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from pydantic import ValidationError
from sqlalchemy.orm import Session
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
@ -63,7 +64,7 @@ class TestWorkflowToolManageService:
"tool_transform_service": mock_tool_transform_service,
}
def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test app and account for testing.
@ -119,14 +120,12 @@ class TestWorkflowToolManageService:
conversation_variables=[],
)
from extensions.ext_database import db
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Update app to reference the workflow
app.workflow_id = workflow.id
db.session.commit()
db_session_with_containers.commit()
return app, account, workflow
@ -153,7 +152,9 @@ class TestWorkflowToolManageService:
),
]
def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_create_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful workflow tool creation with valid parameters.
@ -198,11 +199,10 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
# Verify database state
from extensions.ext_database import db
# Check if workflow tool provider was created
created_tool_provider = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -230,7 +230,7 @@ class TestWorkflowToolManageService:
].workflow_provider_to_controller.assert_called_once()
def test_create_workflow_tool_duplicate_name_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when name already exists.
@ -280,10 +280,9 @@ class TestWorkflowToolManageService:
assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
# Verify only one tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -293,7 +292,7 @@ class TestWorkflowToolManageService:
assert tool_count == 1
def test_create_workflow_tool_invalid_app_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app does not exist.
@ -331,10 +330,9 @@ class TestWorkflowToolManageService:
assert f"App {non_existent_app_id} not found" in str(exc_info.value)
# Verify no workflow tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -344,7 +342,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_invalid_parameters_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when parameters are invalid.
@ -387,10 +385,9 @@ class TestWorkflowToolManageService:
assert "validation error" in str(exc_info.value).lower()
# Verify no workflow tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -400,7 +397,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_duplicate_app_id_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app_id already exists.
@ -450,10 +447,9 @@ class TestWorkflowToolManageService:
assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value)
# Verify only one tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -463,7 +459,7 @@ class TestWorkflowToolManageService:
assert tool_count == 1
def test_create_workflow_tool_workflow_not_found_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when app has no workflow.
@ -481,10 +477,9 @@ class TestWorkflowToolManageService:
)
# Remove workflow reference from app
from extensions.ext_database import db
app.workflow_id = None
db.session.commit()
db_session_with_containers.commit()
# Attempt to create workflow tool for app without workflow
tool_parameters = self._create_test_workflow_tool_parameters()
@ -505,7 +500,7 @@ class TestWorkflowToolManageService:
# Verify no workflow tool was created
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -515,7 +510,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_create_workflow_tool_human_input_node_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation fails when workflow contains human input nodes.
@ -558,10 +553,8 @@ class TestWorkflowToolManageService:
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -570,7 +563,9 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_workflow_tool_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful workflow tool update with valid parameters.
@ -603,10 +598,9 @@ class TestWorkflowToolManageService:
)
# Get the created tool
from extensions.ext_database import db
created_tool = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -641,7 +635,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
# Verify database state was updated
db.session.refresh(created_tool)
db_session_with_containers.refresh(created_tool)
assert created_tool is not None
assert created_tool.name == updated_tool_name
assert created_tool.label == updated_tool_label
@ -658,7 +652,7 @@ class TestWorkflowToolManageService:
mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called()
def test_update_workflow_tool_human_input_node_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool update fails when workflow contains human input nodes.
@ -689,10 +683,8 @@ class TestWorkflowToolManageService:
parameters=initial_tool_parameters,
)
from extensions.ext_database import db
created_tool = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -712,7 +704,7 @@ class TestWorkflowToolManageService:
]
}
)
db.session.commit()
db_session_with_containers.commit()
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
WorkflowToolManageService.update_workflow_tool(
@ -728,10 +720,12 @@ class TestWorkflowToolManageService:
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
db.session.refresh(created_tool)
db_session_with_containers.refresh(created_tool)
assert created_tool.name == original_name
def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies):
def test_update_workflow_tool_not_found_error(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool update fails when tool does not exist.
@ -768,10 +762,9 @@ class TestWorkflowToolManageService:
assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value)
# Verify no workflow tool was created
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
)
@ -781,7 +774,7 @@ class TestWorkflowToolManageService:
assert tool_count == 0
def test_update_workflow_tool_same_name_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool update succeeds when keeping the same name.
@ -813,10 +806,9 @@ class TestWorkflowToolManageService:
)
# Get the created tool
from extensions.ext_database import db
created_tool = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.app_id == app.id,
@ -840,12 +832,12 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
# Verify tool still exists with the same name
db.session.refresh(created_tool)
db_session_with_containers.refresh(created_tool)
assert created_tool.name == first_tool_name
assert created_tool.updated_at is not None
def test_create_workflow_tool_with_file_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILE parameter having a file object as default.
@ -916,7 +908,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
def test_create_workflow_tool_with_files_parameter_default(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test workflow tool creation with FILES (Array[File]) parameter having file objects as default.
@ -991,7 +983,7 @@ class TestWorkflowToolManageService:
assert result == {"result": "success"}
def test_create_workflow_tool_db_commit_before_validation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that database commit happens before validation, causing DB pollution on validation failure.
@ -1035,10 +1027,9 @@ class TestWorkflowToolManageService:
# Verify the tool was NOT created in database
# This is the expected behavior (no pollution)
from extensions.ext_database import db
tool_count = (
db.session.query(WorkflowToolProvider)
db_session_with_containers.query(WorkflowToolProvider)
.where(
WorkflowToolProvider.tenant_id == account.current_tenant.id,
WorkflowToolProvider.name == tool_name,

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@ -79,7 +80,7 @@ class TestWorkflowConverter:
mock_config.app_model_config_dict = {}
return mock_config
def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Helper method to create a test account and tenant for testing.
@ -100,18 +101,16 @@ class TestWorkflowConverter:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
from models.account import TenantAccountJoin, TenantAccountRole
@ -122,15 +121,17 @@ class TestWorkflowConverter:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account):
def _create_test_app(
self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account
):
"""
Helper method to create a test app for testing.
@ -163,10 +164,8 @@ class TestWorkflowConverter:
updated_by=account.id,
)
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
# Create app model config
app_model_config = AppModelConfig(
@ -177,16 +176,16 @@ class TestWorkflowConverter:
created_by=account.id,
updated_by=account.id,
)
db.session.add(app_model_config)
db.session.commit()
db_session_with_containers.add(app_model_config)
db_session_with_containers.commit()
# Link app model config to app
app.app_model_config_id = app_model_config.id
db.session.commit()
db_session_with_containers.commit()
return app
def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
Test successful conversion of app to workflow.
@ -225,19 +224,18 @@ class TestWorkflowConverter:
assert new_app.created_by == account.id
# Verify database state
from extensions.ext_database import db
db.session.refresh(new_app)
db_session_with_containers.refresh(new_app)
assert new_app.id is not None
# Verify workflow was created
workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first()
workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first()
assert workflow is not None
assert workflow.tenant_id == app.tenant_id
assert workflow.type == "chat"
def test_convert_to_workflow_without_app_model_config_error(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling when app model config is missing.
@ -270,16 +268,14 @@ class TestWorkflowConverter:
updated_by=account.id,
)
from extensions.ext_database import db
db.session.add(app)
db.session.commit()
db_session_with_containers.add(app)
db_session_with_containers.commit()
# Act & Assert: Verify proper error handling
workflow_converter = WorkflowConverter()
# Check initial state
initial_workflow_count = db.session.query(Workflow).count()
initial_workflow_count = db_session_with_containers.query(Workflow).count()
with pytest.raises(ValueError, match="App model config is required"):
workflow_converter.convert_to_workflow(
@ -294,12 +290,12 @@ class TestWorkflowConverter:
# Verify database state remains unchanged
# The workflow creation happens in convert_app_model_config_to_workflow
# which is called before the app_model_config check, so we need to clean up
db.session.rollback()
final_workflow_count = db.session.query(Workflow).count()
db_session_with_containers.rollback()
final_workflow_count = db_session_with_containers.query(Workflow).count()
assert final_workflow_count == initial_workflow_count
def test_convert_app_model_config_to_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion of app model config to workflow.
@ -356,16 +352,17 @@ class TestWorkflowConverter:
assert answer_node["id"] == "answer"
# Verify database state
from extensions.ext_database import db
db.session.refresh(workflow)
db_session_with_containers.refresh(workflow)
assert workflow.id is not None
# Verify features were set
features = json.loads(workflow._features) if workflow._features else {}
assert isinstance(features, dict)
def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_convert_to_start_node_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion to start node.
@ -410,7 +407,9 @@ class TestWorkflowConverter:
assert second_variable["label"] == "Number Input"
assert second_variable["type"] == "number"
def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_convert_to_http_request_node_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion to HTTP request node.
@ -436,10 +435,8 @@ class TestWorkflowConverter:
api_endpoint="https://api.example.com/test",
)
from extensions.ext_database import db
db.session.add(api_based_extension)
db.session.commit()
db_session_with_containers.add(api_based_extension)
db_session_with_containers.commit()
# Mock encrypter
mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key"
@ -489,7 +486,7 @@ class TestWorkflowConverter:
assert external_data_variable_node_mapping["external_data"] == code_node["id"]
def test_convert_to_knowledge_retrieval_node_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful conversion to knowledge retrieval node.

View File

@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment
@ -31,7 +31,9 @@ class TestAddDocumentToIndexTask:
"index_processor": mock_processor,
}
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_dataset_and_document(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Helper method to create a test dataset and document for testing.
@ -51,15 +53,15 @@ class TestAddDocumentToIndexTask:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -68,8 +70,8 @@ class TestAddDocumentToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -81,8 +83,8 @@ class TestAddDocumentToIndexTask:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create document
document = Document(
@ -99,15 +101,15 @@ class TestAddDocumentToIndexTask:
enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property works correctly
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, document
def _create_test_segments(self, db_session_with_containers, document, dataset):
def _create_test_segments(self, db_session_with_containers: Session, document, dataset):
"""
Helper method to create test document segments.
@ -138,13 +140,15 @@ class TestAddDocumentToIndexTask:
status="completed",
created_by=document.created_by,
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
return segments
def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_add_document_to_index_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful document indexing with paragraph index type.
@ -180,9 +184,9 @@ class TestAddDocumentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_called_once()
# Verify database state changes
db.session.refresh(document)
db_session_with_containers.refresh(document)
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is True
assert segment.disabled_at is None
assert segment.disabled_by is None
@ -191,7 +195,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_different_index_type(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test document indexing with different index types.
@ -209,10 +213,10 @@ class TestAddDocumentToIndexTask:
# Update document to use different index type
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
# Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -237,9 +241,9 @@ class TestAddDocumentToIndexTask:
assert len(documents) == 3
# Verify database state changes
db.session.refresh(document)
db_session_with_containers.refresh(document)
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is True
assert segment.disabled_at is None
assert segment.disabled_by is None
@ -248,7 +252,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of non-existent document.
@ -275,7 +279,7 @@ class TestAddDocumentToIndexTask:
# because indexing_cache_key is not defined in that case
def test_add_document_to_index_invalid_indexing_status(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of document with invalid indexing status.
@ -294,7 +298,7 @@ class TestAddDocumentToIndexTask:
# Set invalid indexing status
document.indexing_status = "processing"
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the task
add_document_to_index_task(document.id)
@ -304,7 +308,7 @@ class TestAddDocumentToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_add_document_to_index_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling when document's dataset doesn't exist.
@ -326,14 +330,14 @@ class TestAddDocumentToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300)
# Delete the dataset to simulate dataset not found scenario
db.session.delete(dataset)
db.session.commit()
db_session_with_containers.delete(dataset)
db_session_with_containers.commit()
# Act: Execute the task
add_document_to_index_task(document.id)
# Assert: Verify error handling
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.enabled is False
assert document.indexing_status == "error"
assert document.error is not None
@ -348,7 +352,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_parent_child_structure(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test document indexing with parent-child structure.
@ -367,10 +371,10 @@ class TestAddDocumentToIndexTask:
# Update document to use parent-child index type
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
# Create segments with mock child chunks
segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -413,9 +417,9 @@ class TestAddDocumentToIndexTask:
assert len(doc.children) == 2 # Each document has 2 children
# Verify database state changes
db.session.refresh(document)
db_session_with_containers.refresh(document)
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is True
assert segment.disabled_at is None
assert segment.disabled_by is None
@ -424,7 +428,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_with_already_enabled_segments(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test document indexing when segments are already enabled.
@ -459,10 +463,10 @@ class TestAddDocumentToIndexTask:
status="completed",
created_by=document.created_by,
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
# Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing"
@ -488,7 +492,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_auto_disable_log_deletion(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test that auto disable logs are properly deleted during indexing.
@ -515,10 +519,10 @@ class TestAddDocumentToIndexTask:
document_id=document.id,
)
log_entry.id = str(fake.uuid4())
db.session.add(log_entry)
db_session_with_containers.add(log_entry)
auto_disable_logs.append(log_entry)
db.session.commit()
db_session_with_containers.commit()
# Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing"
@ -526,7 +530,9 @@ class TestAddDocumentToIndexTask:
# Verify logs exist before processing
existing_logs = (
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all()
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
assert len(existing_logs) == 2
@ -535,7 +541,9 @@ class TestAddDocumentToIndexTask:
# Assert: Verify auto disable logs were deleted
remaining_logs = (
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all()
db_session_with_containers.query(DatasetAutoDisableLog)
.where(DatasetAutoDisableLog.document_id == document.id)
.all()
)
assert len(remaining_logs) == 0
@ -547,14 +555,14 @@ class TestAddDocumentToIndexTask:
# Verify segments were enabled
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is True
# Verify redis cache was cleared
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_general_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test general exception handling during indexing process.
@ -584,7 +592,7 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify error handling
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.enabled is False
assert document.indexing_status == "error"
assert document.error is not None
@ -593,14 +601,14 @@ class TestAddDocumentToIndexTask:
# Verify segments were not enabled due to error
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is False # Should remain disabled due to error
# Verify redis cache was still cleared despite error
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_segment_filtering_edge_cases(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segment filtering with various edge cases.
@ -638,7 +646,7 @@ class TestAddDocumentToIndexTask:
status="completed",
created_by=document.created_by,
)
db.session.add(segment1)
db_session_with_containers.add(segment1)
segments.append(segment1)
# Segment 2: Should be processed (enabled=True, status="completed")
@ -658,7 +666,7 @@ class TestAddDocumentToIndexTask:
status="completed",
created_by=document.created_by,
)
db.session.add(segment2)
db_session_with_containers.add(segment2)
segments.append(segment2)
# Segment 3: Should NOT be processed (enabled=False, status="processing")
@ -677,7 +685,7 @@ class TestAddDocumentToIndexTask:
status="processing", # Not completed
created_by=document.created_by,
)
db.session.add(segment3)
db_session_with_containers.add(segment3)
segments.append(segment3)
# Segment 4: Should be processed (enabled=False, status="completed")
@ -696,10 +704,10 @@ class TestAddDocumentToIndexTask:
status="completed",
created_by=document.created_by,
)
db.session.add(segment4)
db_session_with_containers.add(segment4)
segments.append(segment4)
db.session.commit()
db_session_with_containers.commit()
# Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing"
@ -728,11 +736,11 @@ class TestAddDocumentToIndexTask:
assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3
# Verify database state changes
db.session.refresh(document)
db.session.refresh(segment1)
db.session.refresh(segment2)
db.session.refresh(segment3)
db.session.refresh(segment4)
db_session_with_containers.refresh(document)
db_session_with_containers.refresh(segment1)
db_session_with_containers.refresh(segment2)
db_session_with_containers.refresh(segment3)
db_session_with_containers.refresh(segment4)
# All segments should be enabled because the task updates ALL segments for the document
assert segment1.enabled is True
@ -744,7 +752,7 @@ class TestAddDocumentToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_add_document_to_index_comprehensive_error_scenarios(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test comprehensive error scenarios and recovery.
@ -779,7 +787,7 @@ class TestAddDocumentToIndexTask:
document.indexing_status = "completed"
document.error = None
document.disabled_at = None
db.session.commit()
db_session_with_containers.commit()
# Set up Redis cache key
indexing_cache_key = f"document_{document.id}_indexing"
@ -789,7 +797,7 @@ class TestAddDocumentToIndexTask:
add_document_to_index_task(document.id)
# Assert: Verify consistent error handling
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.enabled is False, f"Document should be disabled for {error_name}"
assert document.indexing_status == "error", f"Document status should be error for {error_name}"
assert document.error is not None, f"Error should be recorded for {error_name}"
@ -798,7 +806,7 @@ class TestAddDocumentToIndexTask:
# Verify segments remain disabled due to error
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is False, f"Segments should remain disabled for {error_name}"
# Verify redis cache was still cleared despite error

View File

@ -11,8 +11,8 @@ from unittest.mock import Mock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -49,7 +49,7 @@ class TestBatchCleanDocumentTask:
"get_image_ids": mock_get_image_ids,
}
def _create_test_account(self, db_session_with_containers):
def _create_test_account(self, db_session_with_containers: Session):
"""
Helper method to create a test account for testing.
@ -69,16 +69,16 @@ class TestBatchCleanDocumentTask:
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -87,15 +87,15 @@ class TestBatchCleanDocumentTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account
def _create_test_dataset(self, db_session_with_containers, account):
def _create_test_dataset(self, db_session_with_containers: Session, account):
"""
Helper method to create a test dataset for testing.
@ -119,12 +119,12 @@ class TestBatchCleanDocumentTask:
embedding_model_provider="openai",
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def _create_test_document(self, db_session_with_containers, dataset, account):
def _create_test_document(self, db_session_with_containers: Session, dataset, account):
"""
Helper method to create a test document for testing.
@ -153,12 +153,12 @@ class TestBatchCleanDocumentTask:
doc_form="text_model",
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
def _create_test_document_segment(self, db_session_with_containers, document, account):
def _create_test_document_segment(self, db_session_with_containers: Session, document, account):
"""
Helper method to create a test document segment for testing.
@ -186,12 +186,12 @@ class TestBatchCleanDocumentTask:
status="completed",
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
def _create_test_upload_file(self, db_session_with_containers, account):
def _create_test_upload_file(self, db_session_with_containers: Session, account):
"""
Helper method to create a test upload file for testing.
@ -220,13 +220,13 @@ class TestBatchCleanDocumentTask:
used=False,
)
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
def test_batch_clean_document_task_successful_cleanup(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful cleanup of documents with segments and files.
@ -245,7 +245,7 @@ class TestBatchCleanDocumentTask:
# Update document to reference the upload file
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Store original IDs for verification
document_id = document.id
@ -261,18 +261,18 @@ class TestBatchCleanDocumentTask:
# The task should have processed the segment and cleaned up the database
# Verify database cleanup
db.session.commit() # Ensure all changes are committed
db_session_with_containers.commit() # Ensure all changes are committed
# Check that segment is deleted
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
def test_batch_clean_document_task_with_image_files(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup of documents containing image references.
@ -300,8 +300,8 @@ class TestBatchCleanDocumentTask:
status="completed",
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Store original IDs for verification
segment_id = segment.id
@ -313,17 +313,17 @@ class TestBatchCleanDocumentTask:
)
# Verify database cleanup
db.session.commit()
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
# Verify that the task completed successfully by checking the log output
# The task should have processed the segment and cleaned up the database
def test_batch_clean_document_task_no_segments(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup when document has no segments.
@ -339,7 +339,7 @@ class TestBatchCleanDocumentTask:
# Update document to reference the upload file
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Store original IDs for verification
document_id = document.id
@ -354,21 +354,21 @@ class TestBatchCleanDocumentTask:
# Since there are no segments, the task should handle this gracefully
# Verify database cleanup
db.session.commit()
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
# Verify database cleanup
db.session.commit()
db_session_with_containers.commit()
# Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
def test_batch_clean_document_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup when dataset is not found.
@ -386,8 +386,8 @@ class TestBatchCleanDocumentTask:
dataset_id = dataset.id
# Delete the dataset to simulate not found scenario
db.session.delete(dataset)
db.session.commit()
db_session_with_containers.delete(dataset)
db_session_with_containers.commit()
# Execute the task with non-existent dataset
batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[])
@ -399,14 +399,14 @@ class TestBatchCleanDocumentTask:
mock_external_service_dependencies["storage"].delete.assert_not_called()
# Verify that no database cleanup occurred
db.session.commit()
db_session_with_containers.commit()
# Document should still exist since cleanup failed
existing_document = db.session.query(Document).filter_by(id=document_id).first()
existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first()
assert existing_document is not None
def test_batch_clean_document_task_storage_cleanup_failure(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup when storage operations fail.
@ -423,7 +423,7 @@ class TestBatchCleanDocumentTask:
# Update document to reference the upload file
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Store original IDs for verification
document_id = document.id
@ -442,18 +442,18 @@ class TestBatchCleanDocumentTask:
# The task should continue processing even when storage operations fail
# Verify database cleanup still occurred despite storage failure
db.session.commit()
db_session_with_containers.commit()
# Check that segment is deleted from database
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
# Check that upload file is deleted from database
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
def test_batch_clean_document_task_multiple_documents(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup of multiple documents in a single batch operation.
@ -482,7 +482,7 @@ class TestBatchCleanDocumentTask:
segments.append(segment)
upload_files.append(upload_file)
db.session.commit()
db_session_with_containers.commit()
# Store original IDs for verification
document_ids = [doc.id for doc in documents]
@ -498,20 +498,20 @@ class TestBatchCleanDocumentTask:
# The task should process all documents and clean up all associated resources
# Verify database cleanup for all resources
db.session.commit()
db_session_with_containers.commit()
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
def test_batch_clean_document_task_different_doc_forms(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup with different document form types.
@ -527,12 +527,12 @@ class TestBatchCleanDocumentTask:
for doc_form in doc_forms:
dataset = self._create_test_dataset(db_session_with_containers, account)
db.session.commit()
db_session_with_containers.commit()
document = self._create_test_document(db_session_with_containers, dataset, account)
# Update document doc_form
document.doc_form = doc_form
db.session.commit()
db_session_with_containers.commit()
segment = self._create_test_document_segment(db_session_with_containers, document, account)
@ -549,20 +549,20 @@ class TestBatchCleanDocumentTask:
# The task should handle different document forms correctly
# Verify database cleanup
db.session.commit()
db_session_with_containers.commit()
# Check that segment is deleted
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
except Exception as e:
# If the task fails due to external service issues (e.g., plugin daemon),
# we should still verify that the database state is consistent
# This is a common scenario in test environments where external services may not be available
db.session.commit()
db_session_with_containers.commit()
# Check if the segment still exists (task may have failed before deletion)
existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
if existing_segment is not None:
# If segment still exists, the task failed before deletion
# This is acceptable in test environments with external service issues
@ -572,7 +572,7 @@ class TestBatchCleanDocumentTask:
pass
def test_batch_clean_document_task_large_batch_performance(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test cleanup performance with a large batch of documents.
@ -604,7 +604,7 @@ class TestBatchCleanDocumentTask:
segments.append(segment)
upload_files.append(upload_file)
db.session.commit()
db_session_with_containers.commit()
# Store original IDs for verification
document_ids = [doc.id for doc in documents]
@ -629,20 +629,20 @@ class TestBatchCleanDocumentTask:
# The task should handle large batches efficiently
# Verify database cleanup for all resources
db.session.commit()
db_session_with_containers.commit()
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
# Check that all upload files are deleted
for file_id in file_ids:
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
def test_batch_clean_document_task_integration_with_real_database(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test full integration with real database operations.
@ -683,12 +683,12 @@ class TestBatchCleanDocumentTask:
# Add all to database
for segment in segments:
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Verify initial state
assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3
assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None
# Store original IDs for verification
document_id = document.id
@ -704,17 +704,17 @@ class TestBatchCleanDocumentTask:
# The task should process all segments and clean up all associated resources
# Verify database cleanup
db.session.commit()
db_session_with_containers.commit()
# Check that all segments are deleted
for segment_id in segment_ids:
deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first()
deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first()
assert deleted_segment is None
# Check that upload file is deleted
deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first()
deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first()
assert deleted_file is None
# Verify final database state
assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
assert db.session.query(UploadFile).filter_by(id=file_id).first() is None
assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0
assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None

View File

@ -17,6 +17,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -29,20 +30,19 @@ class TestBatchCreateSegmentToIndexTask:
"""Integration tests for batch_create_segment_to_index_task using testcontainers."""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_database import db
from extensions.ext_redis import redis_client
# Clear all test data
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(UploadFile).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -75,7 +75,7 @@ class TestBatchCreateSegmentToIndexTask:
"embedding_model": mock_embedding_model,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
"""
Helper method to create a test account and tenant for testing.
@ -95,18 +95,16 @@ class TestBatchCreateSegmentToIndexTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant for the account
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -115,15 +113,15 @@ class TestBatchCreateSegmentToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_dataset(self, db_session_with_containers, account, tenant):
def _create_test_dataset(self, db_session_with_containers: Session, account, tenant):
"""
Helper method to create a test dataset for testing.
@ -148,14 +146,12 @@ class TestBatchCreateSegmentToIndexTask:
created_by=account.id,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def _create_test_document(self, db_session_with_containers, account, tenant, dataset):
def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset):
"""
Helper method to create a test document for testing.
@ -186,14 +182,12 @@ class TestBatchCreateSegmentToIndexTask:
word_count=0,
)
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
def _create_test_upload_file(self, db_session_with_containers, account, tenant):
def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant):
"""
Helper method to create a test upload file for testing.
@ -221,10 +215,8 @@ class TestBatchCreateSegmentToIndexTask:
used=False,
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
@ -252,7 +244,7 @@ class TestBatchCreateSegmentToIndexTask:
return csv_content
def test_batch_create_segment_to_index_task_success_text_model(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful batch creation of segments for text model documents.
@ -293,11 +285,10 @@ class TestBatchCreateSegmentToIndexTask:
)
# Verify results
from extensions.ext_database import db
# Check that segments were created
segments = (
db.session.query(DocumentSegment)
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
@ -316,7 +307,7 @@ class TestBatchCreateSegmentToIndexTask:
assert segment.answer is None # text_model doesn't have answers
# Check that document word count was updated
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.word_count > 0
# Verify vector service was called
@ -331,7 +322,7 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"completed"
def test_batch_create_segment_to_index_task_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test task failure when dataset does not exist.
@ -370,17 +361,16 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created (since dataset doesn't exist)
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()
segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0
# Verify no documents were modified
documents = db.session.query(Document).all()
documents = db_session_with_containers.query(Document).all()
assert len(documents) == 0
def test_batch_create_segment_to_index_task_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test task failure when document does not exist.
@ -419,18 +409,17 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()
segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0
# Verify dataset remains unchanged (no segments were added to the dataset)
db.session.refresh(dataset)
segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
db_session_with_containers.refresh(dataset)
segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(segments_for_dataset) == 0
def test_batch_create_segment_to_index_task_document_not_available(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test task failure when document is not available for indexing.
@ -498,11 +487,9 @@ class TestBatchCreateSegmentToIndexTask:
),
]
from extensions.ext_database import db
for document in test_cases:
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Test each unavailable document
for document in test_cases:
@ -524,11 +511,11 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created
segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all()
segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all()
assert len(segments) == 0
def test_batch_create_segment_to_index_task_upload_file_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test task failure when upload file does not exist.
@ -567,17 +554,16 @@ class TestBatchCreateSegmentToIndexTask:
assert cache_value == b"error"
# Verify no segments were created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()
segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0
# Verify document remains unchanged
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.word_count == 0
def test_batch_create_segment_to_index_task_empty_csv_file(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test task failure when CSV file is empty.
@ -619,17 +605,16 @@ class TestBatchCreateSegmentToIndexTask:
# Verify error handling
# Since exception was raised, no segments should be created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()
segments = db_session_with_containers.query(DocumentSegment).all()
assert len(segments) == 0
# Verify document remains unchanged
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.word_count == 0
def test_batch_create_segment_to_index_task_position_calculation(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test proper position calculation for segments when existing segments exist.
@ -664,11 +649,9 @@ class TestBatchCreateSegmentToIndexTask:
)
existing_segments.append(segment)
from extensions.ext_database import db
for segment in existing_segments:
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Create CSV content
csv_content = self._create_test_csv_content("text_model")
@ -695,7 +678,7 @@ class TestBatchCreateSegmentToIndexTask:
# Verify results
# Check that new segments were created with correct positions
all_segments = (
db.session.query(DocumentSegment)
db_session_with_containers.query(DocumentSegment)
.filter_by(document_id=document.id)
.order_by(DocumentSegment.position)
.all()
@ -716,7 +699,7 @@ class TestBatchCreateSegmentToIndexTask:
assert segment.completed_at is not None
# Check that document word count was updated
db.session.refresh(document)
db_session_with_containers.refresh(document)
assert document.word_count > 0
# Verify vector service was called

View File

@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import (
@ -37,7 +38,7 @@ class TestCleanDatasetTask:
"""Integration tests for clean_dataset_task using testcontainers."""
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
def cleanup_database(self, db_session_with_containers: Session):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_redis import redis_client
@ -82,7 +83,7 @@ class TestCleanDatasetTask:
"index_processor": mock_index_processor,
}
def _create_test_account_and_tenant(self, db_session_with_containers):
def _create_test_account_and_tenant(self, db_session_with_containers: Session):
"""
Helper method to create a test account and tenant for testing.
@ -127,7 +128,7 @@ class TestCleanDatasetTask:
return account, tenant
def _create_test_dataset(self, db_session_with_containers, account, tenant):
def _create_test_dataset(self, db_session_with_containers: Session, account, tenant):
"""
Helper method to create a test dataset for testing.
@ -157,7 +158,7 @@ class TestCleanDatasetTask:
return dataset
def _create_test_document(self, db_session_with_containers, account, tenant, dataset):
def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset):
"""
Helper method to create a test document for testing.
@ -194,7 +195,7 @@ class TestCleanDatasetTask:
return document
def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document):
def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document):
"""
Helper method to create a test document segment for testing.
@ -230,7 +231,7 @@ class TestCleanDatasetTask:
return segment
def _create_test_upload_file(self, db_session_with_containers, account, tenant):
def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant):
"""
Helper method to create a test upload file for testing.
@ -264,7 +265,7 @@ class TestCleanDatasetTask:
return upload_file
def test_clean_dataset_task_success_basic_cleanup(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful basic dataset cleanup with minimal data.
@ -325,7 +326,7 @@ class TestCleanDatasetTask:
mock_storage.delete.assert_not_called()
def test_clean_dataset_task_success_with_documents_and_segments(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful dataset cleanup with documents and segments.
@ -433,7 +434,7 @@ class TestCleanDatasetTask:
assert mock_storage.delete.call_count == 3
def test_clean_dataset_task_success_with_invalid_doc_form(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful dataset cleanup with invalid doc_form handling.
@ -493,7 +494,7 @@ class TestCleanDatasetTask:
assert mock_factory.call_count == 4
def test_clean_dataset_task_error_handling_and_rollback(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test error handling and rollback mechanism when database operations fail.
@ -542,7 +543,7 @@ class TestCleanDatasetTask:
# This demonstrates the resilience of the cleanup process
def test_clean_dataset_task_with_image_file_references(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test dataset cleanup with image file references in document segments.
@ -634,7 +635,7 @@ class TestCleanDatasetTask:
mock_get_image_ids.assert_called_once()
def test_clean_dataset_task_performance_with_large_dataset(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test dataset cleanup performance with large amounts of data.
@ -704,11 +705,9 @@ class TestCleanDatasetTask:
binding.created_at = datetime.now()
bindings.append(binding)
from extensions.ext_database import db
db.session.add_all(metadata_items)
db.session.add_all(bindings)
db.session.commit()
db_session_with_containers.add_all(metadata_items)
db_session_with_containers.add_all(bindings)
db_session_with_containers.commit()
# Measure cleanup performance
import time
@ -772,7 +771,7 @@ class TestCleanDatasetTask:
print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds")
def test_clean_dataset_task_storage_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test dataset cleanup when storage operations fail.
@ -838,7 +837,7 @@ class TestCleanDatasetTask:
# consistency in the database
def test_clean_dataset_task_edge_cases_and_boundary_conditions(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test dataset cleanup with edge cases and boundary conditions.

View File

@ -13,8 +13,8 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -34,7 +34,7 @@ class TestDisableSegmentFromIndexTask:
mock_processor.clean.return_value = None
yield mock_processor
def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]:
def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]:
"""
Helper method to create a test account and tenant for testing.
@ -53,8 +53,8 @@ class TestDisableSegmentFromIndexTask:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -62,8 +62,8 @@ class TestDisableSegmentFromIndexTask:
status="normal",
plan="basic",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
@ -72,15 +72,15 @@ class TestDisableSegmentFromIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
return account, tenant
def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset:
def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset:
"""
Helper method to create a test dataset.
@ -101,13 +101,18 @@ class TestDisableSegmentFromIndexTask:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
def _create_test_document(
self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model"
self,
db_session_with_containers: Session,
dataset: Dataset,
tenant: Tenant,
account: Account,
doc_form: str = "text_model",
) -> Document:
"""
Helper method to create a test document.
@ -140,13 +145,14 @@ class TestDisableSegmentFromIndexTask:
tokens=500,
completed_at=datetime.now(UTC),
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
def _create_test_segment(
self,
db_session_with_containers: Session,
document: Document,
dataset: Dataset,
tenant: Tenant,
@ -185,12 +191,12 @@ class TestDisableSegmentFromIndexTask:
created_by=account.id,
completed_at=datetime.now(UTC) if status == "completed" else None,
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
def test_disable_segment_success(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor):
"""
Test successful segment disabling from index.
@ -202,9 +208,9 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Set up Redis cache
indexing_cache_key = f"segment_{segment.id}_indexing"
@ -226,10 +232,10 @@ class TestDisableSegmentFromIndexTask:
assert redis_client.get(indexing_cache_key) is None
# Verify segment is still in database
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.id is not None
def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when segment is not found.
@ -251,7 +257,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when segment is not in completed status.
@ -262,9 +268,11 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data with non-completed segment
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(
db_session_with_containers, document, dataset, tenant, account, status="indexing", enabled=True
)
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -275,7 +283,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when segment has no associated dataset.
@ -286,13 +294,13 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Manually remove dataset association
segment.dataset_id = "00000000-0000-0000-0000-000000000000"
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -303,7 +311,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when segment has no associated document.
@ -314,13 +322,13 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Manually remove document association
segment.document_id = "00000000-0000-0000-0000-000000000000"
db.session.commit()
db_session_with_containers.commit()
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -331,7 +339,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when document is disabled.
@ -342,12 +350,12 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data with disabled document
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
document.enabled = False
db.session.commit()
db_session_with_containers.commit()
segment = self._create_test_segment(document, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -358,7 +366,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when document is archived.
@ -369,12 +377,12 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data with archived document
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
document.archived = True
db.session.commit()
db_session_with_containers.commit()
segment = self._create_test_segment(document, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -385,7 +393,9 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_document_indexing_not_completed(
self, db_session_with_containers: Session, mock_index_processor
):
"""
Test handling when document indexing is not completed.
@ -396,12 +406,12 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data with incomplete indexing
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
document.indexing_status = "indexing"
db.session.commit()
db_session_with_containers.commit()
segment = self._create_test_segment(document, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -412,7 +422,7 @@ class TestDisableSegmentFromIndexTask:
# Verify index processor was not called
mock_index_processor.clean.assert_not_called()
def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor):
"""
Test handling when index processor raises an exception.
@ -424,9 +434,9 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Set up Redis cache
indexing_cache_key = f"segment_{segment.id}_indexing"
@ -449,13 +459,13 @@ class TestDisableSegmentFromIndexTask:
assert call_args[0][1] == [segment.index_node_id] # Check index node IDs
# Verify segment was re-enabled
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is True
# Verify Redis cache was still cleared
assert redis_client.get(indexing_cache_key) is None
def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor):
"""
Test disabling segments with different document forms.
@ -470,9 +480,11 @@ class TestDisableSegmentFromIndexTask:
for doc_form in doc_forms:
# Arrange: Create test data for each form
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account, doc_form=doc_form)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(
db_session_with_containers, dataset, tenant, account, doc_form=doc_form
)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Reset mock for each iteration
mock_index_processor.reset_mock()
@ -489,7 +501,7 @@ class TestDisableSegmentFromIndexTask:
assert call_args[0][0].id == dataset.id # Check dataset ID
assert call_args[0][1] == [segment.index_node_id] # Check index node IDs
def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor):
"""
Test Redis cache handling during segment disabling.
@ -500,9 +512,9 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Test with cache present
indexing_cache_key = f"segment_{segment.id}_indexing"
@ -517,13 +529,13 @@ class TestDisableSegmentFromIndexTask:
assert redis_client.get(indexing_cache_key) is None
# Test with no cache present
segment2 = self._create_test_segment(document, dataset, tenant, account)
segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
result2 = disable_segment_from_index_task(segment2.id)
# Assert: Verify task still works without cache
assert result2 is None
def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor):
"""
Test performance timing of segment disabling task.
@ -534,9 +546,9 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task and measure time
start_time = time.perf_counter()
@ -548,7 +560,9 @@ class TestDisableSegmentFromIndexTask:
execution_time = end_time - start_time
assert execution_time < 5.0 # Should complete within 5 seconds
def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_database_session_management(
self, db_session_with_containers: Session, mock_index_processor
):
"""
Test database session management during task execution.
@ -559,9 +573,9 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create test data
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
segment = self._create_test_segment(document, dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
# Act: Execute the task
result = disable_segment_from_index_task(segment.id)
@ -570,10 +584,10 @@ class TestDisableSegmentFromIndexTask:
assert result is None
# Verify segment is still accessible (session was properly managed)
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.id is not None
def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor):
def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor):
"""
Test concurrent execution of segment disabling tasks.
@ -584,12 +598,12 @@ class TestDisableSegmentFromIndexTask:
"""
# Arrange: Create multiple test segments
account, tenant = self._create_test_account_and_tenant(db_session_with_containers)
dataset = self._create_test_dataset(tenant, account)
document = self._create_test_document(dataset, tenant, account)
dataset = self._create_test_dataset(db_session_with_containers, tenant, account)
document = self._create_test_document(db_session_with_containers, dataset, tenant, account)
segments = []
for i in range(3):
segment = self._create_test_segment(document, dataset, tenant, account)
segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account)
segments.append(segment)
# Act: Execute tasks concurrently (simulated)

View File

@ -9,6 +9,7 @@ The task is responsible for removing document segments from the search index whe
from unittest.mock import MagicMock, patch
from faker import Faker
from sqlalchemy.orm import Session
from models import Account, Dataset, DocumentSegment
from models import Document as DatasetDocument
@ -31,7 +32,7 @@ class TestDisableSegmentsFromIndexTask:
and realistic testing environment with actual database interactions.
"""
def _create_test_account(self, db_session_with_containers, fake=None):
def _create_test_account(self, db_session_with_containers: Session, fake=None):
"""
Helper method to create a test account with realistic data.
@ -79,7 +80,7 @@ class TestDisableSegmentsFromIndexTask:
return account
def _create_test_dataset(self, db_session_with_containers, account, fake=None):
def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None):
"""
Helper method to create a test dataset with realistic data.
@ -113,7 +114,7 @@ class TestDisableSegmentsFromIndexTask:
return dataset
def _create_test_document(self, db_session_with_containers, dataset, account, fake=None):
def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None):
"""
Helper method to create a test document with realistic data.
@ -158,7 +159,9 @@ class TestDisableSegmentsFromIndexTask:
return document
def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None):
def _create_test_segments(
self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None
):
"""
Helper method to create test document segments with realistic data.
@ -210,7 +213,7 @@ class TestDisableSegmentsFromIndexTask:
return segments
def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None):
def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None):
"""
Helper method to create a dataset process rule.
@ -239,14 +242,12 @@ class TestDisableSegmentsFromIndexTask:
process_rule.created_by = dataset.created_by
process_rule.updated_by = dataset.updated_by
from extensions.ext_database import db
db.session.add(process_rule)
db.session.commit()
db_session_with_containers.add(process_rule)
db_session_with_containers.commit()
return process_rule
def test_disable_segments_success(self, db_session_with_containers):
def test_disable_segments_success(self, db_session_with_containers: Session):
"""
Test successful disabling of segments from index.
@ -297,7 +298,7 @@ class TestDisableSegmentsFromIndexTask:
expected_key = f"segment_{segment.id}_indexing"
mock_redis.delete.assert_any_call(expected_key)
def test_disable_segments_dataset_not_found(self, db_session_with_containers):
def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session):
"""
Test handling when dataset is not found.
@ -320,7 +321,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when dataset is not found
mock_redis.delete.assert_not_called()
def test_disable_segments_document_not_found(self, db_session_with_containers):
def test_disable_segments_document_not_found(self, db_session_with_containers: Session):
"""
Test handling when document is not found.
@ -344,7 +345,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when document is not found
mock_redis.delete.assert_not_called()
def test_disable_segments_document_invalid_status(self, db_session_with_containers):
def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session):
"""
Test handling when document has invalid status for disabling.
@ -360,9 +361,8 @@ class TestDisableSegmentsFromIndexTask:
# Test case 1: Document not enabled
document.enabled = False
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
segment_ids = [segment.id for segment in segments]
@ -379,7 +379,7 @@ class TestDisableSegmentsFromIndexTask:
# Test case 2: Document archived
document.enabled = True
document.archived = True
db.session.commit()
db_session_with_containers.commit()
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
@ -393,7 +393,7 @@ class TestDisableSegmentsFromIndexTask:
document.enabled = True
document.archived = False
document.indexing_status = "indexing"
db.session.commit()
db_session_with_containers.commit()
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
# Act
@ -403,7 +403,7 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value
mock_redis.delete.assert_not_called()
def test_disable_segments_no_segments_found(self, db_session_with_containers):
def test_disable_segments_no_segments_found(self, db_session_with_containers: Session):
"""
Test handling when no segments are found for the given IDs.
@ -430,7 +430,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when no segments are found
mock_redis.delete.assert_not_called()
def test_disable_segments_index_processor_error(self, db_session_with_containers):
def test_disable_segments_index_processor_error(self, db_session_with_containers: Session):
"""
Test handling when index processor encounters an error.
@ -464,13 +464,14 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value
# Verify segments were rolled back to enabled state
from extensions.ext_database import db
db.session.refresh(segments[0])
db.session.refresh(segments[1])
db_session_with_containers.refresh(segments[0])
db_session_with_containers.refresh(segments[1])
# Check that segments are re-enabled after error
updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
updated_segments = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all()
)
for segment in updated_segments:
assert segment.enabled is True
@ -480,7 +481,7 @@ class TestDisableSegmentsFromIndexTask:
# Verify Redis cache cleanup was still called
assert mock_redis.delete.call_count == len(segments)
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers):
def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session):
"""
Test disabling segments with different document forms.
@ -503,9 +504,8 @@ class TestDisableSegmentsFromIndexTask:
for doc_form in doc_forms:
# Update document form
document.doc_form = doc_form
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock the index processor factory
with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory:
@ -523,7 +523,7 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value
mock_factory.assert_called_with(doc_form)
def test_disable_segments_performance_timing(self, db_session_with_containers):
def test_disable_segments_performance_timing(self, db_session_with_containers: Session):
"""
Test that the task properly measures and logs performance timing.
@ -568,7 +568,7 @@ class TestDisableSegmentsFromIndexTask:
assert performance_log is not None
assert "0.5" in performance_log # Should log the execution time
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers):
def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session):
"""
Test that Redis cache is properly cleaned up for all segments.
@ -610,7 +610,7 @@ class TestDisableSegmentsFromIndexTask:
for expected_key in expected_keys:
assert expected_key in actual_calls
def test_disable_segments_database_session_cleanup(self, db_session_with_containers):
def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session):
"""
Test that database session is properly closed after task execution.
@ -643,7 +643,7 @@ class TestDisableSegmentsFromIndexTask:
assert result is None # Task should complete without returning a value
# Session lifecycle is managed by context manager; no explicit close assertion
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session):
"""
Test handling when empty segment IDs list is provided.
@ -669,7 +669,7 @@ class TestDisableSegmentsFromIndexTask:
# Redis should not be called when no segments are provided
mock_redis.delete.assert_not_called()
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers):
def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session):
"""
Test handling when some segment IDs are valid and others are invalid.

View File

@ -2,9 +2,9 @@ from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.rag.index_processor.constant.index_type import IndexStructureType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
@ -31,7 +31,9 @@ class TestEnableSegmentsToIndexTask:
"index_processor": mock_processor,
}
def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies):
def _create_test_dataset_and_document(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Helper method to create a test dataset and document for testing.
@ -51,15 +53,15 @@ class TestEnableSegmentsToIndexTask:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -68,8 +70,8 @@ class TestEnableSegmentsToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -81,8 +83,8 @@ class TestEnableSegmentsToIndexTask:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create document
document = Document(
@ -99,16 +101,16 @@ class TestEnableSegmentsToIndexTask:
enabled=True,
doc_form=IndexStructureType.PARAGRAPH_INDEX,
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property works correctly
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, document
def _create_test_segments(
self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed"
self, db_session_with_containers: Session, document, dataset, count=3, enabled=False, status="completed"
):
"""
Helper method to create test document segments.
@ -144,14 +146,14 @@ class TestEnableSegmentsToIndexTask:
status=status,
created_by=document.created_by,
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
return segments
def test_enable_segments_to_index_with_different_index_type(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segments indexing with different index types.
@ -169,10 +171,10 @@ class TestEnableSegmentsToIndexTask:
# Update document to use different index type
document.doc_form = IndexStructureType.QA_INDEX
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
# Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -204,7 +206,7 @@ class TestEnableSegmentsToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_enable_segments_to_index_dataset_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of non-existent dataset.
@ -229,7 +231,7 @@ class TestEnableSegmentsToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_document_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of non-existent document.
@ -256,7 +258,7 @@ class TestEnableSegmentsToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_invalid_document_status(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling of document with invalid status.
@ -284,12 +286,12 @@ class TestEnableSegmentsToIndexTask:
document.enabled = True
document.archived = False
document.indexing_status = "completed"
db.session.commit()
db_session_with_containers.commit()
# Set invalid status
for attr, value in status_attrs.items():
setattr(document, attr, value)
db.session.commit()
db_session_with_containers.commit()
# Create segments
segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -304,11 +306,11 @@ class TestEnableSegmentsToIndexTask:
# Clean up segments for next iteration
for segment in segments:
db.session.delete(segment)
db.session.commit()
db_session_with_containers.delete(segment)
db_session_with_containers.commit()
def test_enable_segments_to_index_segments_not_found(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test handling when no segments are found.
@ -338,7 +340,7 @@ class TestEnableSegmentsToIndexTask:
mock_external_service_dependencies["index_processor"].load.assert_not_called()
def test_enable_segments_to_index_with_parent_child_structure(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test segments indexing with parent-child structure.
@ -357,10 +359,10 @@ class TestEnableSegmentsToIndexTask:
# Update document to use parent-child index type
document.doc_form = IndexStructureType.PARENT_CHILD_INDEX
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure doc_form property reflects the updated document
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
# Create segments with mock child chunks
segments = self._create_test_segments(db_session_with_containers, document, dataset)
@ -410,7 +412,7 @@ class TestEnableSegmentsToIndexTask:
assert redis_client.exists(indexing_cache_key) == 0
def test_enable_segments_to_index_general_exception_handling(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test general exception handling during indexing process.
@ -443,7 +445,7 @@ class TestEnableSegmentsToIndexTask:
# Assert: Verify error handling
for segment in segments:
db.session.refresh(segment)
db_session_with_containers.refresh(segment)
assert segment.enabled is False
assert segment.status == "error"
assert segment.error is not None

View File

@ -2,8 +2,8 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from extensions.ext_database import db
from libs.email_i18n import EmailType
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task
@ -30,7 +30,7 @@ class TestMailAccountDeletionTask:
"email_service": mock_email_service,
}
def _create_test_account(self, db_session_with_containers):
def _create_test_account(self, db_session_with_containers: Session):
"""
Helper method to create a test account for testing.
@ -49,16 +49,16 @@ class TestMailAccountDeletionTask:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -67,12 +67,14 @@ class TestMailAccountDeletionTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
return account
def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies):
def test_send_deletion_success_task_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful account deletion success email sending.
@ -109,7 +111,7 @@ class TestMailAccountDeletionTask:
)
def test_send_deletion_success_task_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test account deletion success email when mail service is not initialized.
@ -132,7 +134,7 @@ class TestMailAccountDeletionTask:
mock_external_service_dependencies["email_service"].send_email.assert_not_called()
def test_send_deletion_success_task_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test account deletion success email when email service raises exception.
@ -154,7 +156,7 @@ class TestMailAccountDeletionTask:
mock_external_service_dependencies["email_service"].send_email.assert_called_once()
def test_send_account_deletion_verification_code_success(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test successful account deletion verification code email sending.
@ -193,7 +195,7 @@ class TestMailAccountDeletionTask:
)
def test_send_account_deletion_verification_code_mail_not_initialized(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test account deletion verification code email when mail service is not initialized.
@ -217,7 +219,7 @@ class TestMailAccountDeletionTask:
mock_external_service_dependencies["email_service"].send_email.assert_not_called()
def test_send_account_deletion_verification_code_email_service_exception(
self, db_session_with_containers, mock_external_service_dependencies
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""
Test account deletion verification code email when email service raises exception.

View File

@ -4,11 +4,11 @@ from unittest.mock import patch
import pytest
from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Pipeline
from models.workflow import Workflow
@ -52,7 +52,7 @@ class TestRagPipelineRunTasks:
"delete_file": mock_delete_file,
}
def _create_test_pipeline_and_workflow(self, db_session_with_containers):
def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session):
"""
Helper method to create test pipeline and workflow for testing.
@ -71,15 +71,15 @@ class TestRagPipelineRunTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -88,8 +88,8 @@ class TestRagPipelineRunTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create workflow
workflow = Workflow(
@ -107,8 +107,8 @@ class TestRagPipelineRunTasks:
conversation_variables=[],
rag_pipeline_variables=[],
)
db.session.add(workflow)
db.session.commit()
db_session_with_containers.add(workflow)
db_session_with_containers.commit()
# Create pipeline
pipeline = Pipeline(
@ -119,14 +119,14 @@ class TestRagPipelineRunTasks:
created_by=account.id,
)
pipeline.id = str(uuid.uuid4())
db.session.add(pipeline)
db.session.commit()
db_session_with_containers.add(pipeline)
db_session_with_containers.commit()
# Refresh entities to ensure they're properly loaded
db.session.refresh(account)
db.session.refresh(tenant)
db.session.refresh(workflow)
db.session.refresh(pipeline)
db_session_with_containers.refresh(account)
db_session_with_containers.refresh(tenant)
db_session_with_containers.refresh(workflow)
db_session_with_containers.refresh(pipeline)
return account, tenant, pipeline, workflow
@ -209,7 +209,7 @@ class TestRagPipelineRunTasks:
return json.dumps(entities_data)
def test_priority_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test successful priority RAG pipeline run task execution.
@ -254,7 +254,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_rag_pipeline_run_task_success(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test successful regular RAG pipeline run task execution.
@ -299,7 +299,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_priority_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with waiting tasks in queue using real Redis.
@ -351,7 +351,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining
def test_rag_pipeline_run_task_legacy_compatibility(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility.
@ -419,7 +419,7 @@ class TestRagPipelineRunTasks:
redis_client.delete(legacy_task_key)
def test_rag_pipeline_run_task_with_waiting_tasks(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with waiting tasks in queue using real Redis.
@ -469,7 +469,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining
def test_priority_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in priority RAG pipeline run task using real Redis.
@ -526,7 +526,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_error_handling(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test error handling in regular RAG pipeline run task using real Redis.
@ -581,7 +581,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 0
def test_priority_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in priority RAG pipeline run task using real Redis.
@ -648,7 +648,7 @@ class TestRagPipelineRunTasks:
assert queue1._task_key != queue2._task_key
def test_rag_pipeline_run_task_tenant_isolation(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test tenant isolation in regular RAG pipeline run task using real Redis.
@ -713,7 +713,7 @@ class TestRagPipelineRunTasks:
assert queue1._task_key != queue2._task_key
def test_run_single_rag_pipeline_task_success(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
):
"""
Test successful run_single_rag_pipeline_task execution.
@ -748,7 +748,7 @@ class TestRagPipelineRunTasks:
assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity)
def test_run_single_rag_pipeline_task_entity_validation_error(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with invalid entity data.
@ -793,7 +793,7 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
def test_run_single_rag_pipeline_task_database_entity_not_found(
self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers
self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers
):
"""
Test run_single_rag_pipeline_task with non-existent database entities.
@ -838,7 +838,7 @@ class TestRagPipelineRunTasks:
mock_pipeline_generator.assert_not_called()
def test_priority_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test priority RAG pipeline run task with non-existent file.
@ -888,7 +888,7 @@ class TestRagPipelineRunTasks:
assert len(remaining_tasks) == 0
def test_rag_pipeline_run_task_file_not_found(
self, db_session_with_containers, mock_pipeline_generator, mock_file_service
self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service
):
"""
Test regular RAG pipeline run task with non-existent file.