From ac8bd12609d9f466d8279fad50539aaf4ea194b4 Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Mon, 6 Apr 2026 23:13:22 -0500 Subject: [PATCH] refactor: migrate session.query to select API in small task files (#34617) --- api/tasks/annotation/batch_import_annotations_task.py | 9 ++++++--- api/tasks/annotation/disable_annotation_reply_task.py | 8 +++++--- api/tasks/annotation/enable_annotation_reply_task.py | 8 +++++--- api/tasks/enable_segment_to_index_task.py | 3 ++- api/tasks/recover_document_indexing_task.py | 5 ++++- api/tasks/trigger_subscription_refresh_tasks.py | 7 ++++++- 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index c734e1321b..89844ef44b 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from werkzeug.exceptions import NotFound from core.db.session_factory import session_factory @@ -35,7 +36,9 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: with session_factory.create_session() as session: # get app info - app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1) + ) if app: try: @@ -53,8 +56,8 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: ) documents.append(document) # if annotation reply is enabled , batch add annotations' index - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if app_annotation_setting: diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py index 41cf7ccbf6..6a9b52e7e5 100644 --- a/api/tasks/annotation/disable_annotation_reply_task.py +++ b/api/tasks/annotation/disable_annotation_reply_task.py @@ -24,14 +24,16 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str): start_at = time.perf_counter() # get app info with session_factory.create_session() as session: - app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1) + ) annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id))) if not app: logger.info(click.style(f"App not found: {app_id}", fg="red")) return - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + app_annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if not app_annotation_setting: diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 2c07fe0f31..4cbca13a92 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -36,7 +36,9 @@ def enable_annotation_reply_task( start_at = time.perf_counter() # get app info with session_factory.create_session() as session: - app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first() + app = session.scalar( + select(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").limit(1) + ) if not app: logger.info(click.style(f"App not found: {app_id}", fg="red")) @@ -51,8 +53,8 @@ def enable_annotation_reply_task( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) - annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() + annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) ) if annotation_setting: if dataset_collection_binding.id != annotation_setting.collection_binding_id: diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 5ad17d75d4..8334ca2588 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -29,7 +30,7 @@ def enable_segment_to_index_task(segment_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if not segment: logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index af72023da1..73b121961c 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -24,7 +25,9 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() + document = session.scalar( + select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) + ) if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) diff --git a/api/tasks/trigger_subscription_refresh_tasks.py b/api/tasks/trigger_subscription_refresh_tasks.py index 7698a1a6b8..1daf8f302c 100644 --- a/api/tasks/trigger_subscription_refresh_tasks.py +++ b/api/tasks/trigger_subscription_refresh_tasks.py @@ -4,6 +4,7 @@ from collections.abc import Mapping from typing import Any from celery import shared_task +from sqlalchemy import select from sqlalchemy.orm import Session from configs import dify_config @@ -22,7 +23,11 @@ def _now_ts() -> int: def _load_subscription(session: Session, tenant_id: str, subscription_id: str) -> TriggerSubscription | None: - return session.query(TriggerSubscription).filter_by(tenant_id=tenant_id, id=subscription_id).first() + return session.scalar( + select(TriggerSubscription) + .where(TriggerSubscription.tenant_id == tenant_id, TriggerSubscription.id == subscription_id) + .limit(1) + ) def _refresh_oauth_if_expired(tenant_id: str, subscription: TriggerSubscription, now: int) -> None: