mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
fix: type mismatches (route says uuid: but handler says str) (#36612)
This commit is contained in:
parent
e617435d03
commit
345ba80942
@ -417,7 +417,7 @@ class MessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_model, message_id: str):
|
||||
def get(self, app_model, message_id: UUID):
|
||||
message_id_str = str(message_id)
|
||||
|
||||
message = db.session.scalar(
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
@ -345,14 +346,15 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, app_model: App, variable_id: str):
|
||||
def get(self, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
variable_id=variable_id_str,
|
||||
)
|
||||
return variable
|
||||
|
||||
@ -363,7 +365,7 @@ class VariableApi(Resource):
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def patch(self, app_model: App, variable_id: str):
|
||||
def patch(self, app_model: App, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -390,10 +392,11 @@ class VariableApi(Resource):
|
||||
)
|
||||
args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
variable_id=variable_id_str,
|
||||
)
|
||||
|
||||
new_name = args_model.name
|
||||
@ -434,14 +437,15 @@ class VariableApi(Resource):
|
||||
@console_ns.response(204, "Variable deleted successfully")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def delete(self, app_model: App, variable_id: str):
|
||||
def delete(self, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
variable_id=variable_id_str,
|
||||
)
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
@ -457,7 +461,7 @@ class VariableResetApi(Resource):
|
||||
@console_ns.response(204, "Variable reset (no content)")
|
||||
@console_ns.response(404, "Variable not found")
|
||||
@_api_prerequisite
|
||||
def put(self, app_model: App, variable_id: str):
|
||||
def put(self, app_model: App, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -468,10 +472,11 @@ class VariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, app_id={app_model.id}",
|
||||
)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = _ensure_variable_access(
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id),
|
||||
variable=draft_var_srv.get_variable(variable_id=variable_id_str),
|
||||
app_id=app_model.id,
|
||||
variable_id=variable_id,
|
||||
variable_id=variable_id_str,
|
||||
)
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
|
||||
@ -189,7 +189,7 @@ class WorkflowRunExportApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App, run_id: str):
|
||||
def get(self, app_model: App, run_id: UUID):
|
||||
tenant_id = str(app_model.tenant_id)
|
||||
app_id = str(app_model.id)
|
||||
run_id_str = str(run_id)
|
||||
|
||||
@ -979,7 +979,7 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
@ -996,7 +996,7 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: str):
|
||||
def post(self, dataset_id: UUID):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, marshal, marshal_with
|
||||
@ -168,21 +169,22 @@ class RagPipelineVariableApi(Resource):
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
def get(self, pipeline: Pipeline, variable_id: str):
|
||||
def get(self, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
@marshal_with(workflow_draft_variable_model)
|
||||
@console_ns.expect(console_ns.models[WorkflowDraftVariablePatchPayload.__name__])
|
||||
def patch(self, pipeline: Pipeline, variable_id: str):
|
||||
def patch(self, pipeline: Pipeline, variable_id: UUID):
|
||||
# Request payload for file types:
|
||||
#
|
||||
# Local File:
|
||||
@ -210,11 +212,12 @@ class RagPipelineVariableApi(Resource):
|
||||
payload = WorkflowDraftVariablePatchPayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
|
||||
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||
@ -250,15 +253,16 @@ class RagPipelineVariableApi(Resource):
|
||||
return variable
|
||||
|
||||
@_api_prerequisite
|
||||
def delete(self, pipeline: Pipeline, variable_id: str):
|
||||
def delete(self, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
draft_var_srv.delete_variable(variable)
|
||||
db.session.commit()
|
||||
return Response("", 204)
|
||||
@ -267,7 +271,7 @@ class RagPipelineVariableApi(Resource):
|
||||
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/draft/variables/<uuid:variable_id>/reset")
|
||||
class RagPipelineVariableResetApi(Resource):
|
||||
@_api_prerequisite
|
||||
def put(self, pipeline: Pipeline, variable_id: str):
|
||||
def put(self, pipeline: Pipeline, variable_id: UUID):
|
||||
draft_var_srv = WorkflowDraftVariableService(
|
||||
session=db.session(),
|
||||
)
|
||||
@ -278,11 +282,12 @@ class RagPipelineVariableResetApi(Resource):
|
||||
raise NotFoundError(
|
||||
f"Draft workflow not found, pipeline_id={pipeline.id}",
|
||||
)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||
variable_id_str = str(variable_id)
|
||||
variable = draft_var_srv.get_variable(variable_id=variable_id_str)
|
||||
if variable is None:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
if variable.app_id != pipeline.id:
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||
raise NotFoundError(description=f"variable not found, id={variable_id_str}")
|
||||
|
||||
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
|
||||
db.session.commit()
|
||||
|
||||
@ -901,7 +901,7 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def get(self, pipeline: Pipeline, run_id: str):
|
||||
def get(self, pipeline: Pipeline, run_id: UUID):
|
||||
"""
|
||||
Get workflow run node execution list
|
||||
"""
|
||||
|
||||
@ -174,11 +174,11 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
def put(self, app_model: App, annotation_id: str):
|
||||
def put(self, app_model: App, annotation_id: UUID):
|
||||
"""Update an existing annotation."""
|
||||
payload = AnnotationCreatePayload.model_validate(service_api_ns.payload or {})
|
||||
update_args: UpdateAnnotationArgs = {"question": payload.question, "answer": payload.answer}
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, annotation_id)
|
||||
annotation = AppAnnotationService.update_app_annotation_directly(update_args, app_model.id, str(annotation_id))
|
||||
response = Annotation.model_validate(annotation, from_attributes=True)
|
||||
return response.model_dump(mode="json")
|
||||
|
||||
@ -195,7 +195,7 @@ class AnnotationUpdateDeleteApi(Resource):
|
||||
)
|
||||
@validate_app_token
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, annotation_id: str):
|
||||
def delete(self, app_model: App, annotation_id: UUID):
|
||||
"""Delete an annotation."""
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
|
||||
AppAnnotationService.delete_app_annotation(app_model.id, str(annotation_id))
|
||||
return "", 204
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
@ -50,20 +51,20 @@ class FilePreviewApi(Resource):
|
||||
}
|
||||
)
|
||||
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: str):
|
||||
def get(self, app_model: App, end_user: EndUser, file_id: UUID):
|
||||
"""
|
||||
Preview/Download a file that was uploaded via Service API.
|
||||
|
||||
Provides secure file preview/download functionality.
|
||||
Files can only be accessed if they belong to messages within the requesting app's context.
|
||||
"""
|
||||
file_id = str(file_id)
|
||||
file_id_str = str(file_id)
|
||||
|
||||
# Parse query parameters
|
||||
args = FilePreviewQuery.model_validate(request.args.to_dict())
|
||||
|
||||
# Validate file ownership and get file objects
|
||||
_, upload_file = self._validate_file_ownership(file_id, app_model.id)
|
||||
_, upload_file = self._validate_file_ownership(file_id_str, app_model.id)
|
||||
|
||||
# Get file content generator
|
||||
try:
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from pydantic import BaseModel
|
||||
@ -64,10 +65,11 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str):
|
||||
def get(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Resource for getting datasource plugins."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -77,7 +79,7 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
datasource_plugins: list[dict[Any, Any]] = rag_pipeline_service.get_datasource_plugins(
|
||||
tenant_id=tenant_id, dataset_id=dataset_id, is_published=is_published
|
||||
tenant_id=tenant_id, dataset_id=dataset_id_str, is_published=is_published
|
||||
)
|
||||
return datasource_plugins, 200
|
||||
|
||||
@ -109,10 +111,11 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[DatasourceNodeRunPayload.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str, node_id: str):
|
||||
def post(self, tenant_id: str, dataset_id: UUID, node_id: str):
|
||||
"""Resource for getting datasource plugins."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -120,7 +123,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
payload = DatasourceNodeRunPayload.model_validate(service_api_ns.payload or {})
|
||||
assert isinstance(current_user, Account)
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
|
||||
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(
|
||||
{
|
||||
**payload.model_dump(exclude_none=True),
|
||||
@ -172,10 +175,11 @@ class PipelineRunApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@service_api_ns.expect(service_api_ns.models[PipelineRunApiEntity.__name__])
|
||||
def post(self, tenant_id: str, dataset_id: str):
|
||||
def post(self, tenant_id: str, dataset_id: UUID):
|
||||
"""Resource for running a rag pipeline."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# Verify dataset ownership
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str)
|
||||
dataset = db.session.scalar(stmt)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -186,7 +190,7 @@ class PipelineRunApi(DatasetApiResource):
|
||||
raise Forbidden()
|
||||
|
||||
rag_pipeline_service: RagPipelineService = RagPipelineService()
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)
|
||||
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id_str)
|
||||
try:
|
||||
response: dict[Any, Any] | Generator[str, Any, None] = PipelineGenerateService.generate(
|
||||
pipeline=pipeline,
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal
|
||||
@ -107,17 +108,19 @@ class SegmentApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create single segment."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if document.indexing_status != "completed":
|
||||
@ -150,7 +153,10 @@ class SegmentApi(DatasetApiResource):
|
||||
for args_item in payload.segments:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(payload.segments, document, dataset)
|
||||
return {"data": _marshal_segments_with_summary(segments, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id_str),
|
||||
"doc_form": document.doc_form,
|
||||
}, 200
|
||||
else:
|
||||
return {"error": "Segments is required"}, 400
|
||||
|
||||
@ -165,19 +171,21 @@ class SegmentApi(DatasetApiResource):
|
||||
404: "Dataset or document not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str):
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get segments."""
|
||||
# check dataset
|
||||
page = request.args.get("page", default=1, type=int)
|
||||
limit = request.args.get("limit", default=20, type=int)
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
# check embedding model setting
|
||||
@ -205,7 +213,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
segments, total = SegmentService.get_segments(
|
||||
document_id=document_id,
|
||||
document_id=document_id_str,
|
||||
tenant_id=current_tenant_id,
|
||||
status_list=args.status,
|
||||
keyword=args.keyword,
|
||||
@ -214,7 +222,7 @@ class SegmentApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
response = {
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id),
|
||||
"data": _marshal_segments_with_summary(segments, dataset_id_str),
|
||||
"doc_form": document.doc_form,
|
||||
"total": total,
|
||||
"has_more": len(segments) == limit,
|
||||
@ -240,22 +248,25 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
}
|
||||
)
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
@ -276,18 +287,20 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY:
|
||||
@ -306,15 +319,19 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {})
|
||||
|
||||
updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset)
|
||||
return {"data": _marshal_segment_with_summary(updated_segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {
|
||||
"data": _marshal_segment_with_summary(updated_segment, dataset_id_str),
|
||||
"doc_form": document.doc_form,
|
||||
}, 200
|
||||
|
||||
@service_api_ns.doc("get_segment")
|
||||
@service_api_ns.doc(description="Get a specific segment by ID")
|
||||
@ -325,26 +342,29 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
return {"data": _marshal_segment_with_summary(segment, dataset_id_str), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@service_api_ns.route(
|
||||
@ -369,23 +389,26 @@ class ChildChunkApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Create child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
@ -429,23 +452,26 @@ class ChildChunkApi(DatasetApiResource):
|
||||
404: "Dataset, document, or segment not found",
|
||||
}
|
||||
)
|
||||
def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str):
|
||||
def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Get child chunks."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
@ -461,7 +487,9 @@ class ChildChunkApi(DatasetApiResource):
|
||||
limit = min(args.limit, 100)
|
||||
keyword = args.keyword
|
||||
|
||||
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
|
||||
child_chunks = SegmentService.get_child_chunks(
|
||||
segment_id_str, document_id_str, dataset_id_str, page, limit, keyword
|
||||
)
|
||||
|
||||
return {
|
||||
"data": marshal(child_chunks.items, child_chunk_fields),
|
||||
@ -497,32 +525,38 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
)
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Delete child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# check document
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# check segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id):
|
||||
if str(segment.document_id) != str(document_id_str):
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
# check child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
@ -558,32 +592,38 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
@cloud_edition_billing_resource_check("vector_space", "dataset")
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
|
||||
def patch(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str, child_chunk_id: str):
|
||||
def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
"""Update child chunk."""
|
||||
dataset_id_str = str(dataset_id)
|
||||
# check dataset
|
||||
dataset = db.session.scalar(
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1)
|
||||
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id_str).limit(1)
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
document_id_str = str(document_id)
|
||||
# get document
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
document = DocumentService.get_document(dataset_id_str, document_id_str)
|
||||
if not document:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
segment_id_str = str(segment_id)
|
||||
# get segment
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_tenant_id)
|
||||
segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id)
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
# validate segment belongs to the specified document
|
||||
if str(segment.document_id) != str(document_id):
|
||||
if str(segment.document_id) != str(document_id_str):
|
||||
raise NotFound("Segment not found.")
|
||||
|
||||
child_chunk_id_str = str(child_chunk_id)
|
||||
# get child chunk
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(child_chunk_id=child_chunk_id, tenant_id=current_tenant_id)
|
||||
child_chunk = SegmentService.get_child_chunk_by_id(
|
||||
child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id
|
||||
)
|
||||
if not child_chunk:
|
||||
raise NotFound("Child chunk not found.")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user