refactor: migrate session.query to select API in sync task and services (#34619)

This commit is contained in:
Renzo 2026-04-06 23:23:14 -05:00 committed by GitHub
parent f67a811f7f
commit 68bd29eda2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 13 deletions

View File

@ -1,6 +1,6 @@
import base64
from sqlalchemy import Engine
from sqlalchemy import Engine, select
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
@ -22,8 +22,8 @@ class AttachmentService:
raise AssertionError("must be a sessionmaker or an Engine.")
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = self._session_maker(expire_on_commit=False).scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
if not upload_file:
raise NotFound("File not found")

View File

@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
@ -56,24 +57,24 @@ class PluginParameterService:
# fetch credentials from db
with Session(db.engine) as session:
if credential_id:
db_record = (
session.query(BuiltinToolProvider)
db_record = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.id == credential_id,
)
.first()
.limit(1)
)
else:
db_record = (
session.query(BuiltinToolProvider)
db_record = session.scalar(
select(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
.limit(1)
)
if db_record is None:

View File

@ -29,7 +29,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
start_at = time.perf_counter()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if dataset is None:
raise ValueError("Dataset not found")
@ -45,8 +45,8 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"your subscription."
)
except Exception as e:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if document:
document.indexing_status = IndexingStatus.ERROR
@ -58,7 +58,9 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
return
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return