From f79d8baf63778691d9c6d784305245ce8d6e650e Mon Sep 17 00:00:00 2001
From: Asuka Minato
Date: Tue, 30 Sep 2025 00:38:59 +0900
Subject: [PATCH] Fix: Enable Pyright and Fix Typing Errors in Datasets Controller (#26425)

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 api/controllers/console/datasets/datasets.py  | 18 +++++-----
 .../console/datasets/datasets_document.py     | 15 +++++---
 .../console/datasets/datasets_segments.py     | 18 ++++++----
 api/controllers/console/datasets/external.py  |  7 ++---
 .../console/datasets/hit_testing_base.py      |  6 ++---
 api/controllers/console/datasets/metadata.py  |  3 +--
 .../datasets/rag_pipeline/rag_pipeline.py     |  8 ++---
 .../rag_pipeline/rag_pipeline_datasets.py     | 16 ++-------
 .../rag_pipeline_draft_variable.py            | 36 +++----------------
 api/pyrightconfig.json                        |  1 -
 10 files changed, 53 insertions(+), 75 deletions(-)

diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 2affbd6a42..60eedd2197 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,4 +1,5 @@
-import flask_restx
+from typing import Any, cast
+
 from flask import request
 from flask_login import current_user
 from flask_restx import Resource, fields, marshal, marshal_with, reqparse
@@ -31,12 +32,13 @@ from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fi
 from fields.document_fields import document_status_fields
 from libs.login import login_required
 from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
+from models.account import Account
 from models.dataset import DatasetPermissionEnum
 from models.provider_ids import ModelProviderID
 from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 
 
-def _validate_name(name):
+def _validate_name(name: str) -> str:
     if not name or len(name) < 1 or len(name) > 40:
         raise ValueError("Name must be between 1 to 40 characters.")
     return name
@@ -92,7 +94,7 @@ class DatasetListApi(Resource):
         for embedding_model in embedding_models:
             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}")
 
-        data = marshal(datasets, dataset_detail_fields)
+        data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields))
         for item in data:
             # convert embedding_model_provider to plugin standard format
             if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
@@ -192,7 +194,7 @@ class DatasetListApi(Resource):
                 name=args["name"],
                 description=args["description"],
                 indexing_technique=args["indexing_technique"],
-                account=current_user,
+                account=cast(Account, current_user),
                 permission=DatasetPermissionEnum.ONLY_ME,
                 provider=args["provider"],
                 external_knowledge_api_id=args["external_knowledge_api_id"],
@@ -224,7 +226,7 @@ class DatasetApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        data = marshal(dataset, dataset_detail_fields)
+        data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
         if dataset.indexing_technique == "high_quality":
             if dataset.embedding_model_provider:
                 provider_id = ModelProviderID(dataset.embedding_model_provider)
@@ -369,7 +371,7 @@ class DatasetApi(Resource):
         if dataset is None:
             raise NotFound("Dataset not found.")
 
-        result_data = marshal(dataset, dataset_detail_fields)
+        result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
         tenant_id = current_user.current_tenant_id
 
         if data.get("partial_member_list") and data.get("permission") == "partial_members":
@@ -688,7 +690,7 @@ class DatasetApiKeyApi(Resource):
         )
 
         if current_key_count >= self.max_keys:
-            flask_restx.abort(
+            api.abort(
                 400,
                 message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                 code="max_keys_exceeded",
@@ -733,7 +735,7 @@ class DatasetApiDeleteApi(Resource):
         )
 
         if key is None:
-            flask_restx.abort(404, message="API key not found")
+            api.abort(404, message="API key not found")
 
         db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
         db.session.commit()
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 6aaede0fb3..c5fa2061bf 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -55,6 +55,7 @@ from fields.document_fields import (
 from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
+from models.account import Account
 from models.dataset import DocumentPipelineExecutionLog
 from services.dataset_service import DatasetService, DocumentService
 from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
@@ -418,7 +419,9 @@ class DatasetInitApi(Resource):
 
         try:
             dataset, documents, batch = DocumentService.save_document_without_dataset_id(
-                tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
+                tenant_id=current_user.current_tenant_id,
+                knowledge_config=knowledge_config,
+                account=cast(Account, current_user),
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
@@ -452,7 +455,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
             raise DocumentAlreadyFinishedError()
 
         data_process_rule = document.dataset_process_rule
-        data_process_rule_dict = data_process_rule.to_dict()
+        data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
 
         response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
 
@@ -514,7 +517,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
         if not documents:
             return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
         data_process_rule = documents[0].dataset_process_rule
-        data_process_rule_dict = data_process_rule.to_dict()
+        data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {}
         extract_settings = []
         for document in documents:
             if document.indexing_status in {"completed", "error"}:
@@ -753,7 +756,7 @@ class DocumentApi(DocumentResource):
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
         data_source_info = document.data_source_detail_dict
         response = {
             "id": document.id,
@@ -1073,7 +1076,9 @@ class DocumentRenameApi(DocumentResource):
         if not current_user.is_dataset_editor:
             raise Forbidden()
         dataset = DatasetService.get_dataset(dataset_id)
-        DatasetService.check_dataset_operator_permission(current_user, dataset)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_operator_permission(cast(Account, current_user), dataset)
         parser = reqparse.RequestParser()
         parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index ba552821d2..9f2805e2c6 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -392,7 +392,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
             # send batch add segments task
             redis_client.setnx(indexing_cache_key, "waiting")
             batch_create_segment_to_index_task.delay(
-                str(job_id), upload_file_id, dataset_id, document_id, current_user.current_tenant_id, current_user.id
+                str(job_id),
+                upload_file_id,
+                dataset_id,
+                document_id,
+                current_user.current_tenant_id,
+                current_user.id,
             )
         except Exception as e:
             return {"error": str(e)}, 500
@@ -468,7 +473,8 @@ class ChildChunkAddApi(Resource):
         parser.add_argument("content", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
         try:
-            child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
+            content = args["content"]
+            child_chunk = SegmentService.create_child_chunk(content, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
         return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@@ -557,7 +563,8 @@ class ChildChunkAddApi(Resource):
         parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
         args = parser.parse_args()
         try:
-            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
+            chunks_data = args["chunks"]
+            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
             child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
@@ -674,9 +681,8 @@ class ChildChunkUpdateApi(Resource):
         parser.add_argument("content", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
         try:
-            child_chunk = SegmentService.update_child_chunk(
-                args.get("content"), child_chunk, segment, document, dataset
-            )
+            content = args["content"]
+            child_chunk = SegmentService.update_child_chunk(content, child_chunk, segment, document, dataset)
         except ChildChunkIndexingServiceError as e:
             raise ChildChunkIndexingError(str(e))
         return {"data": marshal(child_chunk, child_chunk_fields)}, 200
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index e8f5a11b41..adf9f53523 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -1,3 +1,5 @@
+from typing import cast
+
 from flask import request
 from flask_login import current_user
 from flask_restx import Resource, fields, marshal, reqparse
@@ -9,13 +11,14 @@ from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.wraps import account_initialization_required, setup_required
 from fields.dataset_fields import dataset_detail_fields
 from libs.login import login_required
+from models.account import Account
 from services.dataset_service import DatasetService
 from services.external_knowledge_service import ExternalDatasetService
 from services.hit_testing_service import HitTestingService
 from services.knowledge_service import ExternalDatasetTestService
 
 
-def _validate_name(name):
+def _validate_name(name: str) -> str:
     if not name or len(name) < 1 or len(name) > 100:
         raise ValueError("Name must be between 1 to 100 characters.")
     return name
@@ -274,7 +277,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
             response = HitTestingService.external_retrieve(
                 dataset=dataset,
                 query=args["query"],
-                account=current_user,
+                account=cast(Account, current_user),
                 external_retrieval_model=args["external_retrieval_model"],
                 metadata_filtering_conditions=args["metadata_filtering_conditions"],
             )
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index cfbfc50873..a68e337135 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,10 +1,11 @@
 import logging
+from typing import cast
 
 from flask_login import current_user
 from flask_restx import marshal, reqparse
 from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
 
-import services.dataset_service
+import services
 from controllers.console.app.error import (
     CompletionRequestError,
     ProviderModelCurrentlyNotSupportError,
@@ -20,6 +21,7 @@ from core.errors.error import (
 )
 from core.model_runtime.errors.invoke import InvokeError
 from fields.hit_testing_fields import hit_testing_record_fields
+from models.account import Account
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
 
@@ -59,7 +61,7 @@ class DatasetsHitTestingBase:
             response = HitTestingService.retrieve(
                 dataset=dataset,
                 query=args["query"],
-                account=current_user,
+                account=cast(Account, current_user),
                 retrieval_model=args["retrieval_model"],
                 external_retrieval_model=args["external_retrieval_model"],
                 limit=10,
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index 53dc80eaa5..dc3cd3fce9 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -62,6 +62,7 @@ class DatasetMetadataApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument("name", type=str, required=True, nullable=False, location="json")
         args = parser.parse_args()
+        name = args["name"]
 
         dataset_id_str = str(dataset_id)
         metadata_id_str = str(metadata_id)
@@ -70,7 +71,7 @@ class DatasetMetadataApi(Resource):
             raise NotFound("Dataset not found.")
         DatasetService.check_dataset_permission(dataset, current_user)
 
-        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
+        metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name)
         return metadata, 200
 
     @setup_required
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
index 6641911243..3af590afc8 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py
@@ -20,13 +20,13 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService
 logger = logging.getLogger(__name__)
 
 
-def _validate_name(name):
+def _validate_name(name: str) -> str:
     if not name or len(name) < 1 or len(name) > 40:
         raise ValueError("Name must be between 1 to 40 characters.")
     return name
 
 
-def _validate_description_length(description):
+def _validate_description_length(description: str) -> str:
     if len(description) > 400:
         raise ValueError("Description cannot exceed 400 characters.")
     return description
@@ -76,7 +76,7 @@ class CustomizedPipelineTemplateApi(Resource):
         )
         parser.add_argument(
             "description",
-            type=str,
+            type=_validate_description_length,
             nullable=True,
             required=False,
             default="",
@@ -133,7 +133,7 @@ class PublishCustomizedPipelineTemplateApi(Resource):
         )
         parser.add_argument(
             "description",
-            type=str,
+            type=_validate_description_length,
             nullable=True,
             required=False,
             default="",
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
index c741bfbf82..404aa42073 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py
@@ -1,5 +1,5 @@
-from flask_login import current_user  # type: ignore  # type: ignore
-from flask_restx import Resource, marshal, reqparse  # type: ignore
+from flask_login import current_user
+from flask_restx import Resource, marshal, reqparse
 from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden
 
@@ -20,18 +20,6 @@ from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo,
 from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
 
 
-def _validate_name(name):
-    if not name or len(name) < 1 or len(name) > 40:
-        raise ValueError("Name must be between 1 to 40 characters.")
-    return name
-
-
-def _validate_description_length(description):
-    if len(description) > 400:
-        raise ValueError("Description cannot exceed 400 characters.")
-    return description
-
-
 @console_ns.route("/rag/pipeline/dataset")
 class CreateRagPipelineDatasetApi(Resource):
     @setup_required
diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
index 38f75402a8..bef6bfd13e 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, NoReturn
+from typing import NoReturn
 
 from flask import Response
 from flask_restx import Resource, fields, inputs, marshal, marshal_with, reqparse
@@ -11,14 +11,12 @@ from controllers.console.app.error import (
     DraftWorkflowNotExist,
 )
 from controllers.console.app.workflow_draft_variable import (
-    _WORKFLOW_DRAFT_VARIABLE_FIELDS,
-    _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
+    _WORKFLOW_DRAFT_VARIABLE_FIELDS,  # type: ignore[private-usage]
+    _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,  # type: ignore[private-usage]
 )
 from controllers.console.datasets.wraps import get_rag_pipeline
 from controllers.console.wraps import account_initialization_required, setup_required
 from controllers.web.error import InvalidArgumentError, NotFoundError
-from core.variables.segment_group import SegmentGroup
-from core.variables.segments import ArrayFileSegment, FileSegment, Segment
 from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db
@@ -34,32 +32,6 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList,
 logger = logging.getLogger(__name__)
 
 
-def _convert_values_to_json_serializable_object(value: Segment) -> Any:
-    if isinstance(value, FileSegment):
-        return value.value.model_dump()
-    elif isinstance(value, ArrayFileSegment):
-        return [i.model_dump() for i in value.value]
-    elif isinstance(value, SegmentGroup):
-        return [_convert_values_to_json_serializable_object(i) for i in value.value]
-    else:
-        return value.value
-
-
-def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
-    value = variable.get_value()
-    # create a copy of the value to avoid affecting the model cache.
-    value = value.model_copy(deep=True)
-    # Refresh the url signature before returning it to client.
-    if isinstance(value, FileSegment):
-        file = value.value
-        file.remote_url = file.generate_url()
-    elif isinstance(value, ArrayFileSegment):
-        files = value.value
-        for file in files:
-            file.remote_url = file.generate_url()
-    return _convert_values_to_json_serializable_object(value)
-
-
 def _create_pagination_parser():
     parser = reqparse.RequestParser()
     parser.add_argument(
@@ -104,7 +76,7 @@ def _api_prerequisite(f):
     @account_initialization_required
     @get_rag_pipeline
     def wrapper(*args, **kwargs):
-        if not isinstance(current_user, Account) or not current_user.is_editor:
+        if not isinstance(current_user, Account) or not current_user.has_edit_permission:
             raise Forbidden()
         return f(*args, **kwargs)
 
diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json
index 9cb1ea9bf1..1e6cd501ad 100644
--- a/api/pyrightconfig.json
+++ b/api/pyrightconfig.json
@@ -6,7 +6,6 @@
     "migrations/",
     "core/rag",
     "extensions",
-    "controllers/console/datasets",
     "core/ops",
     "core/model_runtime",
     "core/workflow/nodes",
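
A note on the recurring pattern above: flask_login exposes `current_user` as a
LocalProxy, which Pyright cannot narrow to the project's `Account` model on its
own, so the patch asserts the type at each call site with `typing.cast`. The
final hunk is what actually turns checking on: Pyright skips every path listed
under "exclude" in pyrightconfig.json, so dropping "controllers/console/datasets"
brings the whole package back under analysis. A minimal sketch of the cast idiom
(the helper name `_current_account` is illustrative, not part of this patch):

    from typing import cast

    from flask_login import current_user

    from models.account import Account


    def _current_account() -> Account:
        # cast() is a no-op at runtime; it only tells the type checker what
        # @login_required already guarantees about the logged-in principal.
        return cast(Account, current_user)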
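The Optional-handling changes follow the same principle: narrow a possibly-None
value before using it, either by raising (`if not dataset: raise NotFound(...)`)
or by supplying a fallback (`rule.to_dict() if rule else {}`). A small sketch of
the fallback form, assuming a nullable `dataset_process_rule` relationship as in
the models touched here (the function name is hypothetical):

    from typing import Any

    from models import Document


    def process_rule_dict(document: Document) -> dict[str, Any]:
        # dataset_process_rule may be None, e.g. for documents created without
        # a processing rule, so guard before calling to_dict(); the patch
        # falls back to an empty dict in that case.
        rule = document.dataset_process_rule
        return rule.to_dict() if rule else {}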