mirror of
https://github.com/langgenius/dify.git
synced 2026-03-26 13:51:05 +08:00
refactor: select in console datasets segments and API key controllers (#34027)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
b4e541e11a
commit
4c32acf857
@ -3,7 +3,7 @@ from typing import Any, cast
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
import services
|
||||
@ -738,20 +738,23 @@ class DatasetIndexingStatusApi(Resource):
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
total_segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document.id),
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
.count()
|
||||
or 0
|
||||
)
|
||||
# Create a dictionary with document attributes and additional fields
|
||||
document_dict = {
|
||||
@ -802,9 +805,12 @@ class DatasetApiKeyApi(Resource):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
current_key_count = (
|
||||
db.session.query(ApiToken)
|
||||
.where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id)
|
||||
.count()
|
||||
db.session.scalar(
|
||||
select(func.count(ApiToken.id)).where(
|
||||
ApiToken.type == self.resource_type, ApiToken.tenant_id == current_tenant_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
if current_key_count >= self.max_keys:
|
||||
@ -839,14 +845,14 @@ class DatasetApiDeleteApi(Resource):
|
||||
def delete(self, api_key_id):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
api_key_id = str(api_key_id)
|
||||
key = (
|
||||
db.session.query(ApiToken)
|
||||
key = db.session.scalar(
|
||||
select(ApiToken)
|
||||
.where(
|
||||
ApiToken.tenant_id == current_tenant_id,
|
||||
ApiToken.type == self.resource_type,
|
||||
ApiToken.id == api_key_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if key is None:
|
||||
@ -857,7 +863,7 @@ class DatasetApiDeleteApi(Resource):
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.delete(key)
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@ -401,10 +401,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -447,10 +447,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -494,7 +494,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
payload = BatchImportPayload.model_validate(console_ns.payload or {})
|
||||
upload_file_id = payload.upload_file_id
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
upload_file = db.session.scalar(select(UploadFile).where(UploadFile.id == upload_file_id).limit(1))
|
||||
if not upload_file:
|
||||
raise NotFound("UploadFile not found.")
|
||||
|
||||
@ -559,10 +559,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -616,10 +616,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -666,10 +666,10 @@ class ChildChunkAddApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
@ -714,24 +714,24 @@ class ChildChunkUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
@ -771,24 +771,24 @@ class ChildChunkUpdateApi(Resource):
|
||||
raise NotFound("Document not found.")
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
segment = db.session.scalar(
|
||||
select(DocumentSegment)
|
||||
.where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_tenant_id)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
# check child chunk
|
||||
child_chunk_id = str(child_chunk_id)
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
child_chunk = db.session.scalar(
|
||||
select(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.id == str(child_chunk_id),
|
||||
ChildChunk.tenant_id == current_tenant_id,
|
||||
ChildChunk.segment_id == segment.id,
|
||||
ChildChunk.document_id == document_id,
|
||||
)
|
||||
.first()
|
||||
.limit(1)
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
@ -1476,8 +1476,8 @@ class TestDatasetIndexingStatusApi:
|
||||
return_value=MagicMock(all=lambda: [document]),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=3,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "dataset-1")
|
||||
@ -1526,13 +1526,6 @@ class TestDatasetIndexingStatusApi:
|
||||
document.error = None
|
||||
document.stopped_at = None
|
||||
|
||||
# First count = completed segments, second = total segments
|
||||
query_mock = MagicMock()
|
||||
query_mock.where.side_effect = [
|
||||
MagicMock(count=lambda: 2),
|
||||
MagicMock(count=lambda: 5),
|
||||
]
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
@ -1544,8 +1537,8 @@ class TestDatasetIndexingStatusApi:
|
||||
return_value=MagicMock(all=lambda: [document]),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=query_mock,
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
side_effect=[2, 5],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "dataset-1")
|
||||
@ -1591,8 +1584,8 @@ class TestDatasetApiKeyApi:
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=3,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.ApiToken.generate_api_key",
|
||||
@ -1625,8 +1618,8 @@ class TestDatasetApiKeyApi:
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=10,
|
||||
),
|
||||
):
|
||||
with pytest.raises(BadRequest) as exc_info:
|
||||
@ -1653,8 +1646,8 @@ class TestDatasetApiDeleteApi:
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=mock_key,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.commit",
|
||||
@ -1681,8 +1674,8 @@ class TestDatasetApiDeleteApi:
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets.db.session.query",
|
||||
return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)),
|
||||
"controllers.console.datasets.datasets.db.session.scalar",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
|
||||
@ -526,8 +526,8 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -621,8 +621,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.redis_client.setnx",
|
||||
@ -706,8 +706,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
@ -738,8 +738,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
@ -770,8 +770,8 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.redis_client.setnx",
|
||||
@ -831,8 +831,8 @@ class TestChildChunkAddApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -880,8 +880,8 @@ class TestChildChunkAddApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=segment,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -924,11 +924,8 @@ class TestChildChunkUpdateApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
side_effect=[
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
|
||||
],
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, child_chunk],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -970,11 +967,8 @@ class TestChildChunkUpdateApi:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
side_effect=[
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)),
|
||||
MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)),
|
||||
],
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
side_effect=[segment, child_chunk],
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
@ -1180,8 +1174,8 @@ class TestSegmentOperationCases:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
@ -1215,8 +1209,8 @@ class TestSegmentOperationCases:
|
||||
return_value=document,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.db.session.query",
|
||||
return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)),
|
||||
"controllers.console.datasets.datasets_segments.db.session.scalar",
|
||||
return_value=upload_file,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user