diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 235d147559..5f681c238f 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -170,46 +170,47 @@ class DatasetDocumentListApi(Resource): raise Forbidden(str(e)) with Session(db.engine) as session: - query = session.execute( - select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) - ).all() - - if search: - search = f"%{search}%" - query = query.filter(Document.name.like(search)) - - if sort.startswith("-"): - sort_logic = desc - sort = sort[1:] - else: - sort_logic = asc - - if sort == "hit_count": - sub_query = ( - db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) - .group_by(DocumentSegment.document_id) - .subquery() + query = session.query(Document).filter_by( + dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == "created_at": - query = query.order_by( - sort_logic(Document.created_at), - sort_logic(Document.position), - ) - else: - query = query.order_by( - desc(Document.created_at), - desc(Document.position), - ) + if search: + search = f"%{search}%" + query = query.filter(Document.name.like(search)) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) - documents = paginated_documents.items - if fetch: - with Session(db.engine) as session: + if sort.startswith("-"): + sort_logic = desc + sort = sort[1:] + else: + sort_logic = asc + + if sort == "hit_count": + sub_query = ( + db.select( + DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count") + ) + .group_by(DocumentSegment.document_id) + .subquery() + ) + + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": + query = query.order_by( + sort_logic(Document.created_at), + sort_logic(Document.position), + ) + else: + query = query.order_by( + desc(Document.created_at), + desc(Document.position), + ) + + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + documents = paginated_documents.items + if fetch: for document in documents: completed_segments = ( session.query(DocumentSegment) @@ -228,17 +229,17 @@ class DatasetDocumentListApi(Resource): document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) - else: - data = marshal(documents, document_fields) - response = { - "data": data, - "has_more": len(documents) == limit, - "limit": limit, - "total": paginated_documents.total, - "page": page, - } + else: + data = marshal(documents, document_fields) + response = { + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, + } - return response + return response documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}