refactor: select in console datasets document controller (#34019)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-03-24 13:57:38 +01:00 committed by GitHub
parent 542c1a14e0
commit e3c1112b15
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 82 additions and 96 deletions

View File

@ -10,7 +10,7 @@ import sqlalchemy as sa
from flask import request, send_file from flask import request, send_file
from flask_restx import Resource, fields, marshal, marshal_with from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, func, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource):
raise Forbidden(str(e)) raise Forbidden(str(e))
# get the latest process rule # get the latest process rule
dataset_process_rule = ( dataset_process_rule = db.session.scalar(
db.session.query(DatasetProcessRule) select(DatasetProcessRule)
.where(DatasetProcessRule.dataset_id == document.dataset_id) .where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.limit(1) .limit(1)
.one_or_none()
) )
if dataset_process_rule: if dataset_process_rule:
mode = dataset_process_rule.mode mode = dataset_process_rule.mode
@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource):
if fetch: if fetch:
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.scalar(
.where( select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT, DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
) )
.count() or 0
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.scalar(
.where( select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT, DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
) )
.count() or 0
) )
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
if data_source_info and "upload_file_id" in data_source_info: if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file = ( file = db.session.scalar(
db.session.query(UploadFile) select(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first() .limit(1)
) )
# raise error if file not found # raise error if file not found
@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if not data_source_info: if not data_source_info:
continue continue
file_id = data_source_info["upload_file_id"] file_id = data_source_info["upload_file_id"]
file_detail = ( file_detail = db.session.scalar(
db.session.query(UploadFile) select(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id) .where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first() .limit(1)
) )
if file_detail is None: if file_detail is None:
@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.scalar(
.where( select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT, DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
) )
.count() or 0
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.scalar(
.where( select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT DocumentSegment.document_id == str(document.id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
) )
.count() or 0
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
document_dict = { document_dict = {
@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource):
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
completed_segments = ( completed_segments = (
db.session.query(DocumentSegment) db.session.scalar(
.where( select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT, DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
) )
.count() or 0
) )
total_segments = ( total_segments = (
db.session.query(DocumentSegment) db.session.scalar(
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) select(func.count(DocumentSegment.id)).where(
.count() DocumentSegment.document_id == str(document_id),
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
or 0
) )
# Create a dictionary with document attributes and additional fields # Create a dictionary with document attributes and additional fields
@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
document = DocumentService.get_document(dataset.id, document_id) document = DocumentService.get_document(dataset.id, document_id)
if not document: if not document:
raise NotFound("Document not found.") raise NotFound("Document not found.")
log = ( log = db.session.scalar(
db.session.query(DocumentPipelineExecutionLog) select(DocumentPipelineExecutionLog)
.filter_by(document_id=document_id) .where(DocumentPipelineExecutionLog.document_id == document_id)
.order_by(DocumentPipelineExecutionLog.created_at.desc()) .order_by(DocumentPipelineExecutionLog.created_at.desc())
.first() .limit(1)
) )
if not log: if not log:
return { return {

View File

@ -2,6 +2,8 @@ from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import ParamSpec, TypeVar from typing import ParamSpec, TypeVar
from sqlalchemy import select
from controllers.console.datasets.error import PipelineNotFoundError from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db from extensions.ext_database import db
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
del kwargs["pipeline_id"] del kwargs["pipeline_id"]
pipeline = ( pipeline = db.session.scalar(
db.session.query(Pipeline) select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
.first()
) )
if not pipeline: if not pipeline:

View File

@ -140,8 +140,8 @@ class TestDatasetDocumentListApi:
return_value=pagination, return_value=pagination,
), ),
patch( patch(
"controllers.console.datasets.datasets_document.db.session.query", "controllers.console.datasets.datasets_document.db.session.scalar",
return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), return_value=2,
), ),
patch( patch(
"controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi:
return_value=MagicMock(), return_value=MagicMock(),
), ),
patch( patch(
"controllers.console.datasets.datasets_document.db.session.query", "controllers.console.datasets.datasets_document.db.session.scalar",
return_value=MagicMock( return_value=log,
filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log))
),
), ),
): ):
response, status = method(api, "ds-1", "doc-1") response, status = method(api, "ds-1", "doc-1")
@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi:
dataset_process_rule=None, dataset_process_rule=None,
) )
query_mock = MagicMock()
query_mock.where.return_value.first.return_value = None
with ( with (
app.test_request_context("/"), app.test_request_context("/"),
patch.object(api, "get_document", return_value=document), patch.object(api, "get_document", return_value=document),
patch( patch(
"controllers.console.datasets.datasets_document.db.session.query", "controllers.console.datasets.datasets_document.db.session.scalar",
return_value=query_mock, return_value=None,
), ),
): ):
with pytest.raises(NotFound): with pytest.raises(NotFound):
@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi:
app.test_request_context("/"), app.test_request_context("/"),
patch.object(api, "get_document", return_value=document), patch.object(api, "get_document", return_value=document),
patch( patch(
"controllers.console.datasets.datasets_document.db.session.query", "controllers.console.datasets.datasets_document.db.session.scalar",
return_value=MagicMock( return_value=upload_file,
where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file)))
),
), ),
patch( patch(
"controllers.console.datasets.datasets_document.ExtractSetting", "controllers.console.datasets.datasets_document.ExtractSetting",
@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases:
return_value=None, return_value=None,
), ),
patch( patch(
"controllers.console.datasets.datasets_document.db.session.query", "controllers.console.datasets.datasets_document.db.session.scalar",
return_value=MagicMock( return_value=process_rule,
where=lambda *a: MagicMock(
order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule))
)
),
), ),
): ):
result = method(api) result = method(api)
@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases:
app.test_request_context("/"), app.test_request_context("/"),
patch.object(api, "get_document", return_value=document), patch.object(api, "get_document", return_value=document),
patch( patch(
"controllers.console.datasets.datasets_document.db.session.query", "controllers.console.datasets.datasets_document.db.session.scalar",
return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), return_value=upload_file,
), ),
patch( patch(
"controllers.console.datasets.datasets_document.ExtractSetting", "controllers.console.datasets.datasets_document.ExtractSetting",

View File

@ -26,12 +26,9 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"), return_value=(Mock(), "tenant-1"),
) )
mock_query = Mock()
mock_query.where.return_value.first.return_value = None
mocker.patch( mocker.patch(
"controllers.console.datasets.wraps.db.session.query", "controllers.console.datasets.wraps.db.session.scalar",
return_value=mock_query, return_value=None,
) )
with pytest.raises(PipelineNotFoundError): with pytest.raises(PipelineNotFoundError):
@ -51,12 +48,9 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"), return_value=(Mock(), "tenant-1"),
) )
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch( mocker.patch(
"controllers.console.datasets.wraps.db.session.query", "controllers.console.datasets.wraps.db.session.scalar",
return_value=mock_query, return_value=pipeline,
) )
result = dummy_view(pipeline_id="pipeline-1") result = dummy_view(pipeline_id="pipeline-1")
@ -76,12 +70,9 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"), return_value=(Mock(), "tenant-1"),
) )
mock_query = Mock()
mock_query.where.return_value.first.return_value = pipeline
mocker.patch( mocker.patch(
"controllers.console.datasets.wraps.db.session.query", "controllers.console.datasets.wraps.db.session.scalar",
return_value=mock_query, return_value=pipeline,
) )
result = dummy_view(pipeline_id="pipeline-1") result = dummy_view(pipeline_id="pipeline-1")
@ -100,18 +91,15 @@ class TestGetRagPipeline:
return_value=(Mock(), "tenant-1"), return_value=(Mock(), "tenant-1"),
) )
def where_side_effect(*args, **kwargs): mock_scalar = mocker.patch(
assert args[0].right.value == "123" "controllers.console.datasets.wraps.db.session.scalar",
return Mock(first=lambda: pipeline) return_value=pipeline,
mock_query = Mock()
mock_query.where.side_effect = where_side_effect
mocker.patch(
"controllers.console.datasets.wraps.db.session.query",
return_value=mock_query,
) )
result = dummy_view(pipeline_id=123) result = dummy_view(pipeline_id=123)
assert result is pipeline assert result is pipeline
# Verify the pipeline_id was cast to string in the where clause
stmt = mock_scalar.call_args[0][0]
where_clauses = stmt.whereclause.clauses
assert where_clauses[0].right.value == "123"