diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 20335d9b9f..77feea47a2 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -8,7 +8,7 @@ import click import pandas as pd from celery import shared_task from graphon.model_runtime.entities.model_entities import ModelType -from sqlalchemy import func +from sqlalchemy import func, select from core.db.session_factory import session_factory from core.model_manager import ModelManager @@ -140,10 +140,8 @@ def batch_create_segment_to_index_task( content = segment["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) - max_position = ( - session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == document_config["id"]) - .scalar() + max_position = session.scalar( + select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document_config["id"]) ) segment_document = DocumentSegment( tenant_id=tenant_id, diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py index ecf6f9cb39..55a99dde7a 100644 --- a/api/tasks/delete_account_task.py +++ b/api/tasks/delete_account_task.py @@ -1,6 +1,7 @@ import logging from celery import shared_task +from sqlalchemy import select from configs import dify_config from core.db.session_factory import session_factory @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") def delete_account_task(account_id): with session_factory.create_session() as session: - account = session.query(Account).where(Account.id == account_id).first() + account = session.scalar(select(Account).where(Account.id == account_id).limit(1)) try: if dify_config.BILLING_ENABLED: BillingService.delete_account(account_id) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 13c651753f..6bc58bdf9c 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -82,7 +82,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st with session_factory.create_session() as session: try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if dataset is None: logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red")) return diff --git a/api/tests/unit_tests/tasks/test_delete_account_task.py b/api/tests/unit_tests/tasks/test_delete_account_task.py index 8a12a4a169..f949c13158 100644 --- a/api/tests/unit_tests/tasks/test_delete_account_task.py +++ b/api/tests/unit_tests/tasks/test_delete_account_task.py @@ -26,9 +26,6 @@ def mock_db_session(): cm.__exit__.return_value = None mock_sf.create_session.return_value = cm - query = MagicMock() - session.query.return_value = query - query.where.return_value = query yield session @@ -49,12 +46,12 @@ def mock_deps(): def _set_account_found(mock_db_session, email: str = "user@example.com"): account = SimpleNamespace(email=email) - mock_db_session.query.return_value.where.return_value.first.return_value = account + mock_db_session.scalar.return_value = account return account def _set_account_missing(mock_db_session): - mock_db_session.query.return_value.where.return_value.first.return_value = None + mock_db_session.scalar.return_value = None class TestDeleteAccountTask: