fix: type mismatches (route says uuid: but handler says str) (#36612)

This commit is contained in:
Lillian 2026-05-25 15:33:32 +08:00 committed by GitHub
parent e617435d03
commit 345ba80942
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 149 additions and 94 deletions

View File

@ -417,7 +417,7 @@ class MessageApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, app_model, message_id: str):
def get(self, app_model, message_id: UUID):
message_id_str = str(message_id)
message = db.session.scalar(

View File

@ -2,6 +2,7 @@ import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, fields, marshal, marshal_with
@ -345,14 +346,15 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, app_model: App, variable_id: str):
def get(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
return variable
@ -363,7 +365,7 @@ class VariableApi(Resource):
@console_ns.response(404, "Variable not found")
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def patch(self, app_model: App, variable_id: str):
def patch(self, app_model: App, variable_id: UUID):
# Request payload for file types:
#
# Local File:
@ -390,10 +392,11 @@ class VariableApi(Resource):
)
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
new_name = args_model.name
@ -434,14 +437,15 @@ class VariableApi(Resource):
@console_ns.response(204, "Variable deleted successfully")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
def delete(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
draft_var_srv.delete_variable(variable)
db.session.commit()
@ -457,7 +461,7 @@ class VariableResetApi(Resource):
@console_ns.response(204, "Variable reset (no content)")
@console_ns.response(404, "Variable not found")
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
def put(self, app_model: App, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -468,10 +472,11 @@ class VariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable_id_str = str(variable_id)
variable = _ensure_variable_access(
variable=draft_var_srv.get_variable(variable_id=variable_id),
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
app_id=app_model.id,
variable_id=variable_id,
variable_id=variable_id_str,
)
resetted = draft_var_srv.reset_variable(draft_workflow, variable)

View File

@ -189,7 +189,7 @@ class WorkflowRunExportApi(Resource):
@login_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App, run_id: str):
def get(self, app_model: App, run_id: UUID):
tenant_id = str(app_model.tenant_id)
app_id = str(app_model.id)
run_id_str = str(run_id)

View File

@ -979,7 +979,7 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
return {"url": DocumentService.get_document_download_url(document)}
@ -996,7 +996,7 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: str):
def post(self, dataset_id: UUID):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})

View File

@ -1,6 +1,7 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from uuid import UUID
from flask import Response, request
from flask_restx import Resource, marshal, marshal_with
@ -168,21 +169,22 @@ class RagPipelineVariableApi(Resource):
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
def get(self, pipeline: Pipeline, variable_id: str):
def get(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
return variable
@_api_prerequisite
@marshal_with(workflow_draft_variable_model)
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
def patch(self, pipeline: Pipeline, variable_id: str):
def patch(self, pipeline: Pipeline, variable_id: UUID):
# Request payload for file types:
#
# Local File:
@ -210,11 +212,12 @@ class RagPipelineVariableApi(Resource):
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
@ -250,15 +253,16 @@ class RagPipelineVariableApi(Resource):
return variable
@_api_prerequisite
def delete(self, pipeline: Pipeline, variable_id: str):
def delete(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
@ -267,7 +271,7 @@ class RagPipelineVariableApi(Resource):
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
class RagPipelineVariableResetApi(Resource):
@_api_prerequisite
def put(self, pipeline: Pipeline, variable_id: str):
def put(self, pipeline: Pipeline, variable_id: UUID):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@ -278,11 +282,12 @@ class RagPipelineVariableResetApi(Resource):
raise NotFoundError(
f"Draft workflow not found, pipeline_id={pipeline.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
variable_id_str = str(variable_id)
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
if variable.app_id != pipeline.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()

View File

@ -901,7 +901,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
@login_required
@account_initialization_required
@get_rag_pipeline
def get(self, pipeline: Pipeline, run_id: str):
def get(self, pipeline: Pipeline, run_id: UUID):
"""
Get workflow run node execution list
"""

View File

@ -174,11 +174,11 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def put(self, app_model: App, annotation_id: str):
def put(self, app_model: App, annotation_id: UUID):
"""Update an existing annotation."""
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
response = Annotation.model_validate(annotation, from_attributes=True)
return response.model_dump(mode="json")
@ -195,7 +195,7 @@ class AnnotationUpdateDeleteApi(Resource):
)
@validate_app_token
@edit_permission_required
def delete(self, app_model: App, annotation_id: str):
def delete(self, app_model: App, annotation_id: UUID):
"""Delete an annotation."""
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
return "", 204

View File

@ -1,5 +1,6 @@
import logging
from urllib.parse import quote
from uuid import UUID
from flask import Response, request
from flask_restx import Resource
@ -50,20 +51,20 @@ class FilePreviewApi(Resource):
}
)
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
def get(self, app_model: App, end_user: EndUser, file_id: str):
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
"""
Preview/Download a file that was uploaded via Service API.
Provides secure file preview/download functionality.
Files can only be accessed if they belong to messages within the requesting app's context.
"""
file_id = str(file_id)
file_id_str = str(file_id)
# Parse query parameters
args = FilePreviewQuery.model_validate(request.args.to_dict())
# Validate file ownership and get file objects
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
# Get file content generator
try:

View File

@ -1,5 +1,6 @@
from collections.abc import Generator
from typing import Any
from uuid import UUID
from flask import request
from pydantic import BaseModel
@ -64,10 +65,11 @@ class DatasourcePluginsApi(DatasetApiResource):
401: "Unauthorized - invalid API token",
}
)
def get(self, tenant_id: str, dataset_id: str):
def get(self, tenant_id: str, dataset_id: UUID):
"""Resource for getting datasource plugins."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -77,7 +79,7 @@ class DatasourcePluginsApi(DatasetApiResource):
rag_pipeline_service: RagPipelineService = RagPipelineService()
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
)
return datasource_plugins, 200
@ -109,10 +111,11 @@ class DatasourceNodeRunApi(DatasetApiResource):
}
)
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
def post(self, tenant_id: str, dataset_id: str, node_id: str):
def post(self, tenant_id: str, dataset_id: UUID, node_id: str):
"""Resource for getting datasource plugins."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -120,7 +123,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
{
**payload.model_dump(exclude_none=True),
@ -172,10 +175,11 @@ class PipelineRunApi(DatasetApiResource):
}
)
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
def post(self, tenant_id: str, dataset_id: str):
def post(self, tenant_id: str, dataset_id: UUID):
"""Resource for running a rag pipeline."""
dataset_id_str = str(dataset_id)
# Verify dataset ownership
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
dataset = db.session.scalar(stmt)
if not dataset:
raise NotFound("Dataset not found.")
@ -186,7 +190,7 @@ class PipelineRunApi(DatasetApiResource):
raise Forbidden()
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
try:
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
pipeline=pipeline,

View File

@ -1,4 +1,5 @@
from typing import Any
from uuid import UUID
from flask import request
from flask_restx import marshal
@ -107,17 +108,19 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Create single segment."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
if document.indexing_status != "completed":
@ -150,7 +153,10 @@ class SegmentApi(DatasetApiResource):
for args_item in payload.segments:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
return {
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
}, 200
else:
return {"error": "Segments is required"}, 400
@ -165,19 +171,21 @@ class SegmentApi(DatasetApiResource):
404: "Dataset or document not found",
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str):
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get segments."""
# check dataset
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
dataset_id_str = str(dataset_id)
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
# check embedding model setting
@ -205,7 +213,7 @@ class SegmentApi(DatasetApiResource):
)
segments, total = SegmentService.get_segments(
document_id=document_id,
document_id=document_id_str,
tenant_id=current_tenant_id,
status_list=args.status,
keyword=args.keyword,
@ -214,7 +222,7 @@ class SegmentApi(DatasetApiResource):
)
response = {
"data": _marshal_segments_with_summary(segments, dataset_id),
"data": _marshal_segments_with_summary(segments, dataset_id_str),
"doc_form": document.doc_form,
"total": total,
"has_more": len(segments) == limit,
@ -240,22 +248,25 @@ class DatasetSegmentApi(DatasetApiResource):
}
)
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
@ -276,18 +287,20 @@ class DatasetSegmentApi(DatasetApiResource):
)
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
@ -306,15 +319,19 @@ class DatasetSegmentApi(DatasetApiResource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
return {
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
"doc_form": document.doc_form,
}, 200
@service_api_ns.doc("get_segment")
@service_api_ns.doc(description="Get a specific segment by ID")
@ -325,26 +342,29 @@ class DatasetSegmentApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
@service_api_ns.route(
@ -369,23 +389,26 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Create child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -429,23 +452,26 @@ class ChildChunkApi(DatasetApiResource):
404: "Dataset, document, or segment not found",
}
)
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Get child chunks."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
@ -461,7 +487,9 @@ class ChildChunkApi(DatasetApiResource):
limit = min(args.limit, 100)
keyword = args.keyword
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
child_chunks = SegmentService.get_child_chunks(
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
@ -497,32 +525,38 @@ class DatasetChildChunkApi(DatasetApiResource):
)
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Delete child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# check document
document = DocumentService.get_document(dataset.id, document_id)
document = DocumentService.get_document(dataset.id, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# check segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id):
if str(segment.document_id) != str(document_id_str):
raise NotFound("Document not found.")
child_chunk_id_str = str(child_chunk_id)
# check child chunk
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")
@ -558,32 +592,38 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
_, current_tenant_id = current_account_with_tenant()
"""Update child chunk."""
dataset_id_str = str(dataset_id)
# check dataset
dataset = db.session.scalar(
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
)
if not dataset:
raise NotFound("Dataset not found.")
document_id_str = str(document_id)
# get document
document = DocumentService.get_document(dataset_id, document_id)
document = DocumentService.get_document(dataset_id_str, document_id_str)
if not document:
raise NotFound("Document not found.")
segment_id_str = str(segment_id)
# get segment
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
if not segment:
raise NotFound("Segment not found.")
# validate segment belongs to the specified document
if str(segment.document_id) != str(document_id):
if str(segment.document_id) != str(document_id_str):
raise NotFound("Segment not found.")
child_chunk_id_str = str(child_chunk_id)
# get child chunk
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
child_chunk = SegmentService.get_child_chunk_by_id(
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
)
if not child_chunk:
raise NotFound("Child chunk not found.")