From 1194957fde945601be9ec99cadc8c9f8627ededd Mon Sep 17 00:00:00 2001 From: Renzo <170978465+RenzoMXD@users.noreply.github.com> Date: Mon, 6 Apr 2026 23:25:55 -0500 Subject: [PATCH] refactor: migrate session.query to select API in end_user_service and small tasks (#34620) --- api/services/end_user_service.py | 32 +++++++++---------- .../tools/mcp_tools_manage_service.py | 2 +- api/tasks/disable_segment_from_index_task.py | 3 +- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index 326f46780d..29ada270ec 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping -from sqlalchemy import case +from sqlalchemy import case, select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom @@ -25,14 +25,14 @@ class EndUserService: """ with Session(db.engine, expire_on_commit=False) as session: - return ( - session.query(EndUser) + return session.scalar( + select(EndUser) .where( EndUser.id == end_user_id, EndUser.tenant_id == tenant_id, EndUser.app_id == app_id, ) - .first() + .limit(1) ) @classmethod @@ -57,8 +57,8 @@ class EndUserService: with Session(db.engine, expire_on_commit=False) as session: # Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility # This single query approach is more efficient than separate queries - end_user = ( - session.query(EndUser) + end_user = session.scalar( + select(EndUser) .where( EndUser.tenant_id == tenant_id, EndUser.app_id == app_id, @@ -68,7 +68,7 @@ class EndUserService: # Prioritize records with matching type (0 = match, 1 = no match) case((EndUser.type == type, 0), else_=1) ) - .first() + .limit(1) ) if end_user: @@ -137,15 +137,15 @@ class EndUserService: with Session(db.engine, expire_on_commit=False) as session: # Fetch existing end users for all target apps in a single query - existing_end_users: list[EndUser] = ( - session.query(EndUser) - .where( - EndUser.tenant_id == tenant_id, - EndUser.app_id.in_(unique_app_ids), - EndUser.session_id == user_id, - EndUser.type == type, - ) - .all() + existing_end_users: list[EndUser] = list( + session.scalars( + select(EndUser).where( + EndUser.tenant_id == tenant_id, + EndUser.app_id.in_(unique_app_ids), + EndUser.session_id == user_id, + EndUser.type == type, + ) + ).all() ) found_app_ids: set[str] = set() diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index deb26438a8..690b06ea7d 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -285,7 +285,7 @@ class MCPToolManageService: # Batch query all users to avoid N+1 problem user_ids = {provider.user_id for provider in mcp_providers} - users = self._session.query(Account).where(Account.id.in_(user_ids)).all() + users = self._session.scalars(select(Account).where(Account.id.in_(user_ids))).all() user_name_map = {user.id: user.name for user in users} return [ diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py index bc45171623..dd1a40844b 100644 --- a/api/tasks/disable_segment_from_index_task.py +++ b/api/tasks/disable_segment_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import select from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -24,7 +25,7 @@ def disable_segment_from_index_task(segment_id: str): start_at = time.perf_counter() with session_factory.create_session() as session: - segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first() + segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1)) if not segment: logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return