mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 00:18:28 +08:00
refactor: select in console datasets document controller (#34019)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
542c1a14e0
commit
e3c1112b15
@ -10,7 +10,7 @@ import sqlalchemy as sa
|
|||||||
from flask import request, send_file
|
from flask import request, send_file
|
||||||
from flask_restx import Resource, fields, marshal, marshal_with
|
from flask_restx import Resource, fields, marshal, marshal_with
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import asc, desc, select
|
from sqlalchemy import asc, desc, func, select
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
@ -211,12 +211,11 @@ class GetProcessRuleApi(Resource):
|
|||||||
raise Forbidden(str(e))
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
# get the latest process rule
|
# get the latest process rule
|
||||||
dataset_process_rule = (
|
dataset_process_rule = db.session.scalar(
|
||||||
db.session.query(DatasetProcessRule)
|
select(DatasetProcessRule)
|
||||||
.where(DatasetProcessRule.dataset_id == document.dataset_id)
|
.where(DatasetProcessRule.dataset_id == document.dataset_id)
|
||||||
.order_by(DatasetProcessRule.created_at.desc())
|
.order_by(DatasetProcessRule.created_at.desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
.one_or_none()
|
|
||||||
)
|
)
|
||||||
if dataset_process_rule:
|
if dataset_process_rule:
|
||||||
mode = dataset_process_rule.mode
|
mode = dataset_process_rule.mode
|
||||||
@ -330,21 +329,23 @@ class DatasetDocumentListApi(Resource):
|
|||||||
if fetch:
|
if fetch:
|
||||||
for document in documents:
|
for document in documents:
|
||||||
completed_segments = (
|
completed_segments = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.scalar(
|
||||||
.where(
|
select(func.count(DocumentSegment.id)).where(
|
||||||
DocumentSegment.completed_at.isnot(None),
|
DocumentSegment.completed_at.isnot(None),
|
||||||
DocumentSegment.document_id == str(document.id),
|
DocumentSegment.document_id == str(document.id),
|
||||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.count()
|
or 0
|
||||||
)
|
)
|
||||||
total_segments = (
|
total_segments = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.scalar(
|
||||||
.where(
|
select(func.count(DocumentSegment.id)).where(
|
||||||
DocumentSegment.document_id == str(document.id),
|
DocumentSegment.document_id == str(document.id),
|
||||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.count()
|
or 0
|
||||||
)
|
)
|
||||||
document.completed_segments = completed_segments
|
document.completed_segments = completed_segments
|
||||||
document.total_segments = total_segments
|
document.total_segments = total_segments
|
||||||
@ -521,10 +522,10 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
|||||||
if data_source_info and "upload_file_id" in data_source_info:
|
if data_source_info and "upload_file_id" in data_source_info:
|
||||||
file_id = data_source_info["upload_file_id"]
|
file_id = data_source_info["upload_file_id"]
|
||||||
|
|
||||||
file = (
|
file = db.session.scalar(
|
||||||
db.session.query(UploadFile)
|
select(UploadFile)
|
||||||
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
|
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# raise error if file not found
|
# raise error if file not found
|
||||||
@ -586,10 +587,10 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
|||||||
if not data_source_info:
|
if not data_source_info:
|
||||||
continue
|
continue
|
||||||
file_id = data_source_info["upload_file_id"]
|
file_id = data_source_info["upload_file_id"]
|
||||||
file_detail = (
|
file_detail = db.session.scalar(
|
||||||
db.session.query(UploadFile)
|
select(UploadFile)
|
||||||
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_detail is None:
|
if file_detail is None:
|
||||||
@ -672,20 +673,23 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
|||||||
documents_status = []
|
documents_status = []
|
||||||
for document in documents:
|
for document in documents:
|
||||||
completed_segments = (
|
completed_segments = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.scalar(
|
||||||
.where(
|
select(func.count(DocumentSegment.id)).where(
|
||||||
DocumentSegment.completed_at.isnot(None),
|
DocumentSegment.completed_at.isnot(None),
|
||||||
DocumentSegment.document_id == str(document.id),
|
DocumentSegment.document_id == str(document.id),
|
||||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.count()
|
or 0
|
||||||
)
|
)
|
||||||
total_segments = (
|
total_segments = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.scalar(
|
||||||
.where(
|
select(func.count(DocumentSegment.id)).where(
|
||||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
DocumentSegment.document_id == str(document.id),
|
||||||
|
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.count()
|
or 0
|
||||||
)
|
)
|
||||||
# Create a dictionary with document attributes and additional fields
|
# Create a dictionary with document attributes and additional fields
|
||||||
document_dict = {
|
document_dict = {
|
||||||
@ -723,18 +727,23 @@ class DocumentIndexingStatusApi(DocumentResource):
|
|||||||
document = self.get_document(dataset_id, document_id)
|
document = self.get_document(dataset_id, document_id)
|
||||||
|
|
||||||
completed_segments = (
|
completed_segments = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.scalar(
|
||||||
.where(
|
select(func.count(DocumentSegment.id)).where(
|
||||||
DocumentSegment.completed_at.isnot(None),
|
DocumentSegment.completed_at.isnot(None),
|
||||||
DocumentSegment.document_id == str(document_id),
|
DocumentSegment.document_id == str(document_id),
|
||||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
.count()
|
or 0
|
||||||
)
|
)
|
||||||
total_segments = (
|
total_segments = (
|
||||||
db.session.query(DocumentSegment)
|
db.session.scalar(
|
||||||
.where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT)
|
select(func.count(DocumentSegment.id)).where(
|
||||||
.count()
|
DocumentSegment.document_id == str(document_id),
|
||||||
|
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a dictionary with document attributes and additional fields
|
# Create a dictionary with document attributes and additional fields
|
||||||
@ -1258,11 +1267,11 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
|||||||
document = DocumentService.get_document(dataset.id, document_id)
|
document = DocumentService.get_document(dataset.id, document_id)
|
||||||
if not document:
|
if not document:
|
||||||
raise NotFound("Document not found.")
|
raise NotFound("Document not found.")
|
||||||
log = (
|
log = db.session.scalar(
|
||||||
db.session.query(DocumentPipelineExecutionLog)
|
select(DocumentPipelineExecutionLog)
|
||||||
.filter_by(document_id=document_id)
|
.where(DocumentPipelineExecutionLog.document_id == document_id)
|
||||||
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
.order_by(DocumentPipelineExecutionLog.created_at.desc())
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not log:
|
if not log:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -2,6 +2,8 @@ from collections.abc import Callable
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from controllers.console.datasets.error import PipelineNotFoundError
|
from controllers.console.datasets.error import PipelineNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_account_with_tenant
|
from libs.login import current_account_with_tenant
|
||||||
@ -24,10 +26,8 @@ def get_rag_pipeline(view_func: Callable[P, R]):
|
|||||||
|
|
||||||
del kwargs["pipeline_id"]
|
del kwargs["pipeline_id"]
|
||||||
|
|
||||||
pipeline = (
|
pipeline = db.session.scalar(
|
||||||
db.session.query(Pipeline)
|
select(Pipeline).where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id).limit(1)
|
||||||
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_tenant_id)
|
|
||||||
.first()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not pipeline:
|
if not pipeline:
|
||||||
|
|||||||
@ -140,8 +140,8 @@ class TestDatasetDocumentListApi:
|
|||||||
return_value=pagination,
|
return_value=pagination,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.db.session.query",
|
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)),
|
return_value=2,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
|
"controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status",
|
||||||
@ -700,10 +700,8 @@ class TestDocumentPipelineExecutionLogApi:
|
|||||||
return_value=MagicMock(),
|
return_value=MagicMock(),
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.db.session.query",
|
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||||
return_value=MagicMock(
|
return_value=log,
|
||||||
filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log))
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
response, status = method(api, "ds-1", "doc-1")
|
response, status = method(api, "ds-1", "doc-1")
|
||||||
@ -827,15 +825,12 @@ class TestDocumentIndexingEstimateApi:
|
|||||||
dataset_process_rule=None,
|
dataset_process_rule=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
query_mock = MagicMock()
|
|
||||||
query_mock.where.return_value.first.return_value = None
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch.object(api, "get_document", return_value=document),
|
patch.object(api, "get_document", return_value=document),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.db.session.query",
|
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||||
return_value=query_mock,
|
return_value=None,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
with pytest.raises(NotFound):
|
with pytest.raises(NotFound):
|
||||||
@ -863,10 +858,8 @@ class TestDocumentIndexingEstimateApi:
|
|||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch.object(api, "get_document", return_value=document),
|
patch.object(api, "get_document", return_value=document),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.db.session.query",
|
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||||
return_value=MagicMock(
|
return_value=upload_file,
|
||||||
where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file)))
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.ExtractSetting",
|
"controllers.console.datasets.datasets_document.ExtractSetting",
|
||||||
@ -1239,12 +1232,8 @@ class TestDocumentPermissionCases:
|
|||||||
return_value=None,
|
return_value=None,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.db.session.query",
|
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||||
return_value=MagicMock(
|
return_value=process_rule,
|
||||||
where=lambda *a: MagicMock(
|
|
||||||
order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule))
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
result = method(api)
|
result = method(api)
|
||||||
@ -1364,8 +1353,8 @@ class TestDocumentIndexingEdgeCases:
|
|||||||
app.test_request_context("/"),
|
app.test_request_context("/"),
|
||||||
patch.object(api, "get_document", return_value=document),
|
patch.object(api, "get_document", return_value=document),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.db.session.query",
|
"controllers.console.datasets.datasets_document.db.session.scalar",
|
||||||
return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)),
|
return_value=upload_file,
|
||||||
),
|
),
|
||||||
patch(
|
patch(
|
||||||
"controllers.console.datasets.datasets_document.ExtractSetting",
|
"controllers.console.datasets.datasets_document.ExtractSetting",
|
||||||
|
|||||||
@ -26,12 +26,9 @@ class TestGetRagPipeline:
|
|||||||
return_value=(Mock(), "tenant-1"),
|
return_value=(Mock(), "tenant-1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_query = Mock()
|
|
||||||
mock_query.where.return_value.first.return_value = None
|
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"controllers.console.datasets.wraps.db.session.query",
|
"controllers.console.datasets.wraps.db.session.scalar",
|
||||||
return_value=mock_query,
|
return_value=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(PipelineNotFoundError):
|
with pytest.raises(PipelineNotFoundError):
|
||||||
@ -51,12 +48,9 @@ class TestGetRagPipeline:
|
|||||||
return_value=(Mock(), "tenant-1"),
|
return_value=(Mock(), "tenant-1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_query = Mock()
|
|
||||||
mock_query.where.return_value.first.return_value = pipeline
|
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"controllers.console.datasets.wraps.db.session.query",
|
"controllers.console.datasets.wraps.db.session.scalar",
|
||||||
return_value=mock_query,
|
return_value=pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = dummy_view(pipeline_id="pipeline-1")
|
result = dummy_view(pipeline_id="pipeline-1")
|
||||||
@ -76,12 +70,9 @@ class TestGetRagPipeline:
|
|||||||
return_value=(Mock(), "tenant-1"),
|
return_value=(Mock(), "tenant-1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_query = Mock()
|
|
||||||
mock_query.where.return_value.first.return_value = pipeline
|
|
||||||
|
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"controllers.console.datasets.wraps.db.session.query",
|
"controllers.console.datasets.wraps.db.session.scalar",
|
||||||
return_value=mock_query,
|
return_value=pipeline,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = dummy_view(pipeline_id="pipeline-1")
|
result = dummy_view(pipeline_id="pipeline-1")
|
||||||
@ -100,18 +91,15 @@ class TestGetRagPipeline:
|
|||||||
return_value=(Mock(), "tenant-1"),
|
return_value=(Mock(), "tenant-1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def where_side_effect(*args, **kwargs):
|
mock_scalar = mocker.patch(
|
||||||
assert args[0].right.value == "123"
|
"controllers.console.datasets.wraps.db.session.scalar",
|
||||||
return Mock(first=lambda: pipeline)
|
return_value=pipeline,
|
||||||
|
|
||||||
mock_query = Mock()
|
|
||||||
mock_query.where.side_effect = where_side_effect
|
|
||||||
|
|
||||||
mocker.patch(
|
|
||||||
"controllers.console.datasets.wraps.db.session.query",
|
|
||||||
return_value=mock_query,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = dummy_view(pipeline_id=123)
|
result = dummy_view(pipeline_id=123)
|
||||||
|
|
||||||
assert result is pipeline
|
assert result is pipeline
|
||||||
|
# Verify the pipeline_id was cast to string in the where clause
|
||||||
|
stmt = mock_scalar.call_args[0][0]
|
||||||
|
where_clauses = stmt.whereclause.clauses
|
||||||
|
assert where_clauses[0].right.value == "123"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user