refactor: migrate session.query to select API in small task files batch (#34684)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-04-07 17:58:23 -05:00 committed by GitHub
parent 5aa2524d33
commit cb55176612
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 8 additions and 12 deletions

View File

@ -8,7 +8,7 @@ import click
import pandas as pd
from celery import shared_task
from graphon.model_runtime.entities.model_entities import ModelType
from sqlalchemy import func
from sqlalchemy import func, select
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
@ -140,10 +140,8 @@ def batch_create_segment_to_index_task(
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
max_position = session.scalar(
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document_config["id"])
)
segment_document = DocumentSegment(
tenant_id=tenant_id,

View File

@ -1,6 +1,7 @@
import logging
from celery import shared_task
from sqlalchemy import select
from configs import dify_config
from core.db.session_factory import session_factory
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
with session_factory.create_session() as session:
account = session.query(Account).where(Account.id == account_id).first()
account = session.scalar(select(Account).where(Account.id == account_id).limit(1))
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)

View File

@ -82,7 +82,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return

View File

@ -26,9 +26,6 @@ def mock_db_session():
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@ -49,12 +46,12 @@ def mock_deps():
def _set_account_found(mock_db_session, email: str = "user@example.com"):
account = SimpleNamespace(email=email)
mock_db_session.query.return_value.where.return_value.first.return_value = account
mock_db_session.scalar.return_value = account
return account
def _set_account_missing(mock_db_session):
mock_db_session.query.return_value.where.return_value.first.return_value = None
mock_db_session.scalar.return_value = None
class TestDeleteAccountTask: