refactor: migrate session.query to select API in end_user_service and small tasks (#34620)

This commit is contained in:
Renzo 2026-04-06 23:25:55 -05:00 committed by GitHub
parent 68bd29eda2
commit 1194957fde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 19 additions and 18 deletions

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from sqlalchemy import case
from sqlalchemy import case, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
@ -25,14 +25,14 @@ class EndUserService:
"""
with Session(db.engine, expire_on_commit=False) as session:
return (
session.query(EndUser)
return session.scalar(
select(EndUser)
.where(
EndUser.id == end_user_id,
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
)
.first()
.limit(1)
)
@classmethod
@ -57,8 +57,8 @@ class EndUserService:
with Session(db.engine, expire_on_commit=False) as session:
# Query with ORDER BY to prioritize exact type matches while maintaining backward compatibility
# This single query approach is more efficient than separate queries
end_user = (
session.query(EndUser)
end_user = session.scalar(
select(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id == app_id,
@ -68,7 +68,7 @@ class EndUserService:
# Prioritize records with matching type (0 = match, 1 = no match)
case((EndUser.type == type, 0), else_=1)
)
.first()
.limit(1)
)
if end_user:
@ -137,15 +137,15 @@ class EndUserService:
with Session(db.engine, expire_on_commit=False) as session:
# Fetch existing end users for all target apps in a single query
existing_end_users: list[EndUser] = (
session.query(EndUser)
.where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
.all()
existing_end_users: list[EndUser] = list(
session.scalars(
select(EndUser).where(
EndUser.tenant_id == tenant_id,
EndUser.app_id.in_(unique_app_ids),
EndUser.session_id == user_id,
EndUser.type == type,
)
).all()
)
found_app_ids: set[str] = set()

View File

@ -285,7 +285,7 @@ class MCPToolManageService:
# Batch query all users to avoid N+1 problem
user_ids = {provider.user_id for provider in mcp_providers}
users = self._session.query(Account).where(Account.id.in_(user_ids)).all()
users = self._session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
user_name_map = {user.id: user.name for user in users}
return [

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -24,7 +25,7 @@ def disable_segment_from_index_task(segment_id: str):
start_at = time.perf_counter()
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
segment = session.scalar(select(DocumentSegment).where(DocumentSegment.id == segment_id).limit(1))
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return