diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 5aa243597a..4c3cbce832 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable from functools import wraps -from typing import Any, TypedDict +from typing import Any, Concatenate, TypedDict from uuid import UUID from flask import Response, request @@ -214,7 +214,9 @@ workflow_draft_variable_list_model = console_ns.model( ) -def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: +def _api_prerequisite[T, **P, R]( + f: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R | Response]: """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -231,8 +233,8 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: @edit_permission_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) @wraps(f) - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: - return f(*args, **kwargs) + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response: + return f(self, *args, **kwargs) return wrapper diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index b736c3129d..07db712fba 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -24,14 +24,14 @@ from extensions.ext_database import db from fields.base import ResponseModel from libs.datetime_utils import naive_utc_now from libs.helper import dump_response, to_timestamp -from libs.login import current_account_with_tenant, login_required -from models import DataSourceOauthBinding, Document +from libs.login import login_required +from models import Account, DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService from services.datasource_provider_service import DatasourceProviderService from tasks.document_indexing_sync_task import document_indexing_sync_task from .. import console_ns -from ..wraps import account_initialization_required, setup_required +from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user class NotionEstimatePayload(BaseModel): @@ -130,9 +130,8 @@ class DataSourceApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__]) - def get(self) -> tuple[dict[str, Any], int]: - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]: # get workspace data source integrates data_source_integrates = db.session.scalars( select(DataSourceOauthBinding).where( @@ -180,8 +179,10 @@ class DataSourceApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def patch(self, binding_id: UUID, action: Literal["enable", "disable"]) -> tuple[dict[str, str], int]: - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def patch( + self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"] + ) -> tuple[dict[str, str], int]: binding_id_str = str(binding_id) with sessionmaker(db.engine, expire_on_commit=False).begin() as session: data_source_binding = session.execute( @@ -220,9 +221,9 @@ class DataSourceNotionListApi(Resource): @account_initialization_required @console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery)) @console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__]) - def get(self) -> tuple[dict[str, Any], int]: - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]: query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True)) datasource_provider_service = DatasourceProviderService() credential = datasource_provider_service.get_datasource_credentials( @@ -311,9 +312,8 @@ class DataSourceNotionPreviewApi(Resource): @account_initialization_required @console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery)) @console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__]) - def get(self, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]: - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]: query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True)) datasource_provider_service = DatasourceProviderService() @@ -347,9 +347,8 @@ class DataSourceNotionIndexingEstimateApi(Resource): @account_initialization_required @console_ns.expect(console_ns.models[NotionEstimatePayload.__name__]) @console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__]) - def post(self) -> tuple[dict[str, Any], int]: - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]: payload = NotionEstimatePayload.model_validate(console_ns.payload or {}) args = payload.model_dump() # validate args diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 5b1588e2dc..401aa2454e 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -44,8 +44,8 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from libs.datetime_utils import naive_utc_now from libs.helper import dump_response, to_timestamp -from libs.login import current_account_with_tenant, login_required -from models import DatasetProcessRule, Document, DocumentSegment, UploadFile +from libs.login import login_required +from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog from models.enums import IndexingStatus, SegmentStatus from services.dataset_service import DatasetService, DocumentService @@ -71,6 +71,8 @@ from ..wraps import ( cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, setup_required, + with_current_tenant_id, + with_current_user, ) logger = logging.getLogger(__name__) @@ -169,8 +171,9 @@ register_response_schema_models( class DocumentResource(Resource): - def get_document(self, dataset_id: str, document_id: str) -> Document: - current_user, current_tenant_id = current_account_with_tenant() + def get_document( + self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str + ) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") @@ -190,8 +193,7 @@ class DocumentResource(Resource): return document - def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]: - current_user, _ = current_account_with_tenant() + def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") @@ -218,8 +220,8 @@ class GetProcessRuleApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account): req_data = request.args document_id = req_data.get("document_id") @@ -279,8 +281,9 @@ class DatasetDocumentListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) raw_args = request.args.to_dict() param = DocumentDatasetListParam.model_validate(raw_args) @@ -405,8 +408,8 @@ class DatasetDocumentListApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[KnowledgeConfig.__name__]) @console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__]) - def post(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -480,9 +483,10 @@ class DatasetInitApi(Resource): @account_initialization_required @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor - current_user, current_tenant_id = current_account_with_tenant() if not current_user.is_dataset_editor: raise Forbidden() @@ -539,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, document_id: UUID): - _, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - document = self.get_document(dataset_id_str, document_id_str) + document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() @@ -604,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, batch: str): - _, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str): dataset_id_str = str(dataset_id) - documents = self.get_batch_documents(dataset_id_str, batch) + documents = self.get_batch_documents(dataset_id_str, batch, current_user) if not documents: return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200 data_process_rule = documents[0].dataset_process_rule @@ -704,9 +710,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, batch: str): + @with_current_user + def get(self, current_user: Account, dataset_id: UUID, batch: str): dataset_id_str = str(dataset_id) - documents = self.get_batch_documents(dataset_id_str, batch) + documents = self.get_batch_documents(dataset_id_str, batch, current_user) documents_status = [] for document in documents: completed_segments = ( @@ -759,16 +766,18 @@ class DocumentIndexingStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, document_id: UUID): + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - document = self.get_document(dataset_id_str, document_id_str) + document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) completed_segments = ( db.session.scalar( select(func.count(DocumentSegment.id)).where( DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id_str), + DocumentSegment.document_id == document_id_str, DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) ) @@ -777,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource): total_segments = ( db.session.scalar( select(func.count(DocumentSegment.id)).where( - DocumentSegment.document_id == str(document_id_str), + DocumentSegment.document_id == document_id_str, DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) ) @@ -820,10 +829,12 @@ class DocumentApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, document_id: UUID): + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - document = self.get_document(dataset_id_str, document_id_str) + document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: @@ -909,7 +920,9 @@ class DocumentApi(DocumentResource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(204, "Document deleted successfully") - def delete(self, dataset_id: UUID, document_id: UUID): + @with_current_user + @with_current_tenant_id + def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -918,7 +931,7 @@ class DocumentApi(DocumentResource): # check user's model setting DatasetService.check_dataset_model_setting(dataset) - document = self.get_document(dataset_id_str, document_id_str) + document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) try: DocumentService.delete_document(document) @@ -939,9 +952,11 @@ class DocumentDownloadApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]: + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]: # Reuse the shared permission/tenant checks implemented in DocumentResource. - document = self.get_document(str(dataset_id), str(document_id)) + document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id) return {"url": DocumentService.get_document_download_url(document)} @@ -956,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource): @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__]) - def post(self, dataset_id: UUID): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): """Stream a ZIP archive containing the requested uploaded documents.""" # Parse and validate request payload. payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {}) - current_user, current_tenant_id = current_account_with_tenant() dataset_id_str = str(dataset_id) document_ids: list[str] = [str(document_id) for document_id in payload.document_ids] upload_files, download_name = DocumentService.prepare_document_batch_download_zip( @@ -1003,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]): - current_user, _ = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def patch( + self, + current_tenant_id: str, + current_user: Account, + dataset_id: UUID, + document_id: UUID, + action: Literal["pause", "resume"], + ): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - document = self.get_document(dataset_id_str, document_id_str) + document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor if not current_user.is_dataset_editor: @@ -1051,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource): @setup_required @login_required @account_initialization_required - def put(self, dataset_id: UUID, document_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - document = self.get_document(dataset_id_str, document_id_str) + document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {}) @@ -1100,8 +1125,10 @@ class DocumentStatusApi(DocumentResource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]): - current_user, _ = current_account_with_tenant() + @with_current_user + def patch( + self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"] + ): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: @@ -1216,8 +1243,6 @@ class DocumentRetryApi(DocumentResource): raise NotFound("Dataset not found.") for document_id in payload.document_ids: try: - document_id = str(document_id) - document = DocumentService.get_document(dataset.id, document_id) # 404 if document not found @@ -1248,9 +1273,9 @@ class DocumentRenameApi(DocumentResource): @account_initialization_required @console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__]) @console_ns.expect(console_ns.models[DocumentRenamePayload.__name__]) - def post(self, dataset_id: UUID, document_id: UUID): + @with_current_user + def post(self, current_user: Account, dataset_id: UUID, document_id: UUID): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator - current_user, _ = current_account_with_tenant() if not current_user.is_dataset_editor: raise Forbidden() dataset = DatasetService.get_dataset(dataset_id) @@ -1273,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def get(self, dataset_id: UUID, document_id: UUID): + @with_current_tenant_id + def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID): """sync website document.""" - _, current_tenant_id = current_account_with_tenant() dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: @@ -1351,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self, dataset_id: UUID): + @with_current_user + def post(self, current_user: Account, dataset_id: UUID): """ Generate summary index for specified documents. @@ -1359,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource): (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true), then asynchronously generates summary indexes for the provided documents. """ - current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) # Get dataset @@ -1444,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, document_id: UUID): + @with_current_user + def get(self, current_user: Account, dataset_id: UUID, document_id: UUID): """ Get summary index generation status for a document. @@ -1457,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource): - not_started: Number of segments without summary records - summaries: List of summary records with status and content preview """ - current_user, _ = current_account_with_tenant() dataset_id_str = str(dataset_id) document_id_str = str(document_id) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 77a6462427..4e521100ab 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -33,6 +33,8 @@ from controllers.console.wraps import ( cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, setup_required, + with_current_tenant_id, + with_current_user, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager @@ -51,7 +53,8 @@ from fields.segment_fields import ( ) from graphon.model_runtime.entities.model_entities import ModelType from libs.helper import dump_response, escape_like_pattern -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService, SegmentService @@ -164,9 +167,9 @@ class DatasetDocumentSegmentListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, document_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -274,9 +277,8 @@ class DatasetDocumentSegmentListApi(Resource): @console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT) @console_ns.doc(params=query_params_from_model(SegmentIdListQuery)) @console_ns.response(204, "Segments deleted successfully") - def delete(self, dataset_id: UUID, document_id: UUID): - current_user, _ = current_account_with_tenant() - + @with_current_user + def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -312,9 +314,16 @@ class DatasetDocumentSegmentApi(Resource): @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def patch( + self, + current_tenant_id: str, + current_user: Account, + dataset_id: UUID, + document_id: UUID, + action: Literal["enable", "disable"], + ): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if not dataset: @@ -373,9 +382,9 @@ class DatasetDocumentSegmentAddApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[SegmentCreatePayload.__name__]) @console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__]) - def post(self, dataset_id: UUID, document_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -431,9 +440,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__]) @console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__]) - def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def patch( + self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID + ): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -500,9 +511,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT) @console_ns.response(204, "Segment deleted successfully") - def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def delete( + self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID + ): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -548,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource): @cloud_edition_billing_knowledge_limit_check("add_segment") @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[BatchImportPayload.__name__]) - def post(self, dataset_id: UUID, document_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -619,9 +632,11 @@ class ChildChunkAddApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__]) @console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__]) - def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def post( + self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID + ): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -677,9 +692,8 @@ class ChildChunkAddApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -731,9 +745,11 @@ class ChildChunkAddApi(Resource): console_ns.models[ChildChunkBatchUpdateResponse.__name__], ) @console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__]) - def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def patch( + self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID + ): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -781,9 +797,17 @@ class ChildChunkUpdateApi(Resource): @cloud_edition_billing_rate_limit_check("knowledge") @console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK) @console_ns.response(204, "Child chunk deleted successfully") - def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def delete( + self, + current_tenant_id: str, + current_user: Account, + dataset_id: UUID, + document_id: UUID, + segment_id: UUID, + child_chunk_id: UUID, + ): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -840,9 +864,17 @@ class ChildChunkUpdateApi(Resource): @console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK) @console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__]) @console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__]) - def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def patch( + self, + current_tenant_id: str, + current_user: Account, + dataset_id: UUID, + document_id: UUID, + segment_id: UUID, + child_chunk_id: UUID, + ): # check dataset dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 71b3ab32ef..2bd9f12b29 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -15,6 +15,7 @@ from controllers.console.wraps import ( edit_permission_required, setup_required, with_current_tenant_id, + with_current_user, ) from fields.dataset_fields import ( dataset_detail_fields, @@ -29,7 +30,8 @@ from fields.dataset_fields import ( vector_setting_fields, weighted_score_fields, ) -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService @@ -152,8 +154,9 @@ class ExternalApiTemplateListApi(Resource): @login_required @account_initialization_required @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__]) - def post(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {}) ExternalDatasetService.validate_api_list(payload.settings) @@ -182,8 +185,8 @@ class ExternalApiTemplateApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, external_knowledge_api_id: UUID): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str, external_knowledge_api_id: UUID): external_knowledge_api_id_str = str(external_knowledge_api_id) external_knowledge_api = ExternalDatasetService.get_external_knowledge_api( external_knowledge_api_id_str, current_tenant_id @@ -197,8 +200,9 @@ class ExternalApiTemplateApi(Resource): @login_required @account_initialization_required @console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__]) - def patch(self, external_knowledge_api_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID): external_knowledge_api_id_str = str(external_knowledge_api_id) payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {}) @@ -217,8 +221,9 @@ class ExternalApiTemplateApi(Resource): @login_required @account_initialization_required @console_ns.response(204, "External knowledge API deleted successfully") - def delete(self, external_knowledge_api_id: UUID): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID): external_knowledge_api_id_str = str(external_knowledge_api_id) if not (current_user.has_edit_permission or current_user.is_dataset_operator): @@ -237,8 +242,8 @@ class ExternalApiUseCheckApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, external_knowledge_api_id: UUID): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str, external_knowledge_api_id: UUID): external_knowledge_api_id_str = str(external_knowledge_api_id) external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check( @@ -259,9 +264,10 @@ class ExternalDatasetCreateApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self): + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account): # The role of the current user in the ta table must be admin, owner, or editor - current_user, current_tenant_id = current_account_with_tenant() payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {}) args = payload.model_dump(exclude_none=True) @@ -293,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource): @setup_required @login_required @account_initialization_required - def post(self, dataset_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) if dataset is None: diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 6445f06297..c7cae85d37 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -9,11 +9,18 @@ from configs import dify_config from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from core.plugin.impl.oauth import OAuthHandler from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService @@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource): @login_required @account_initialization_required @edit_permission_required - def get(self, provider_id: str): - current_user, current_tenant_id = current_account_with_tenant() - + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account, provider_id: str): tenant_id = current_tenant_id - credential_id = request.args.get("credential_id") datasource_provider_id = DatasourceProviderID(provider_id) provider_name = datasource_provider_id.provider_name @@ -174,9 +180,8 @@ class DatasourceAuth(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str, provider_id: str): payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() @@ -195,10 +200,11 @@ class DatasourceAuth(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider_id: str): + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, user: Account, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() - user, current_tenant_id = current_account_with_tenant() datasources = datasource_provider_service.list_datasource_credentials( tenant_id=current_tenant_id, @@ -217,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) plugin_id = datasource_provider_id.plugin_id provider_name = datasource_provider_id.provider_name @@ -242,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {}) @@ -265,9 +269,8 @@ class DatasourceAuthListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, current_tenant_id: str): datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id) return {"result": jsonable_encoder(datasources)}, 200 @@ -278,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def get(self, current_tenant_id: str): datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id) return {"result": jsonable_encoder(datasources)}, 200 @@ -293,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str, provider_id: str): payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() @@ -311,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def delete(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def delete(self, current_tenant_id: str, provider_id: str): datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_oauth_custom_client_params( @@ -331,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str, provider_id: str): payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() @@ -353,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource): @login_required @account_initialization_required @edit_permission_required - def post(self, provider_id: str): - _, current_tenant_id = current_account_with_tenant() - + @with_current_tenant_id + def post(self, current_tenant_id: str, provider_id: str): payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {}) datasource_provider_id = DatasourceProviderID(provider_id) datasource_provider_service = DatasourceProviderService() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 4baedf662e..f6fc5afc78 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable -from typing import Any, NoReturn +from typing import Any, Concatenate, NoReturn from uuid import UUID from flask import Response, request @@ -57,7 +57,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel): register_schema_models(console_ns, WorkflowDraftVariablePatchPayload) -def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: +def _api_prerequisite[T, **P, R]( + f: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R | Response]: """Common prerequisites for all draft workflow variable APIs. It ensures the following conditions are satisfied: @@ -72,10 +74,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]: @login_required @account_initialization_required @get_rag_pipeline - def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response: + def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response: if not isinstance(current_user, Account) or not current_user.has_edit_permission: raise Forbidden() - return f(*args, **kwargs) + return f(self, *args, **kwargs) return wrapper diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 984a128376..e58f34dc3b 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -42,6 +42,8 @@ from controllers.console.wraps import ( enterprise_license_required, only_edition_cloud, setup_required, + with_current_tenant_id, + with_current_user, ) from extensions.ext_database import db from fields.base import ResponseModel @@ -49,8 +51,8 @@ from fields.member_fields import Account as AccountResponse from graphon.file import helpers as file_helpers from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp -from libs.login import current_account_with_tenant, login_required -from models import AccountIntegrate, InvitationCode +from libs.login import login_required +from models import Account, AccountIntegrate, InvitationCode from models.account import AccountStatus, InvitationCodeStatus from models.enums import CreatorUserRole from models.model import UploadFile @@ -258,9 +260,8 @@ class AccountInitApi(Resource): @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) @setup_required @login_required - def post(self): - account, _ = current_account_with_tenant() - + @with_current_user + def post(self, account: Account): if account.status == "active": raise AccountAlreadyInitedError() @@ -306,8 +307,8 @@ class AccountProfileApi(Resource): @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) @enterprise_license_required - def get(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account): return _serialize_account(current_user) @@ -318,8 +319,8 @@ class AccountNameApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountNamePayload.model_validate(payload) updated_account = AccountService.update_account(current_user, name=args.name) @@ -336,8 +337,9 @@ class AccountAvatarApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) avatar = args.avatar @@ -362,8 +364,8 @@ class AccountAvatarApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountAvatarPayload.model_validate(payload) @@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountInterfaceLanguagePayload.model_validate(payload) @@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountInterfaceThemePayload.model_validate(payload) @@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountTimezonePayload.model_validate(payload) @@ -430,8 +432,8 @@ class AccountPasswordApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = AccountPasswordPayload.model_validate(payload) @@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__]) - def get(self): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account): account_integrates = db.session.scalars( select(AccountIntegrate).where(AccountIntegrate.account_id == account.id) ).all() @@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__]) - def get(self): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account): token, code = AccountService.generate_account_deletion_verification_code(account) AccountService.send_account_deletion_verification_email(account, code) @@ -511,9 +511,8 @@ class AccountDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - account, _ = current_account_with_tenant() - + @with_current_user + def post(self, account: Account): payload = console_ns.payload or {} args = AccountDeletePayload.model_validate(payload) @@ -547,9 +546,8 @@ class EducationVerifyApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled @console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__]) - def get(self): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account): return EducationVerifyResponse.model_validate( BillingService.EducationIdentity.verify(account.id, account.email) or {} ).model_dump(mode="json") @@ -563,9 +561,8 @@ class EducationApi(Resource): @account_initialization_required @only_edition_cloud @cloud_edition_billing_enabled - def post(self): - account, _ = current_account_with_tenant() - + @with_current_user + def post(self, account: Account): payload = console_ns.payload or {} args = EducationActivatePayload.model_validate(payload) @@ -577,9 +574,8 @@ class EducationApi(Resource): @only_edition_cloud @cloud_edition_billing_enabled @console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__]) - def get(self): - account, _ = current_account_with_tenant() - + @with_current_user + def get(self, account: Account): res = BillingService.EducationIdentity.status(account.id) or {} # convert expire_at to UTC timestamp from isoformat if res and "expire_at" in res: @@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = ChangeEmailSendPayload.model_validate(payload) @@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = ChangeEmailValidityPayload.model_validate(payload) @@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__]) - def post(self): + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = ChangeEmailResetPayload.model_validate(payload) normalized_new_email = args.new_email.lower() @@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource): if not AccountService.check_email_unique(normalized_new_email): raise EmailAlreadyInUseError() - current_user, _ = current_account_with_tenant() reset_data = AccountService.get_change_email_data(args.token) if not reset_data: raise InvalidTokenError() diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 925f3e1197..c539debf08 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -14,10 +14,16 @@ from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + is_admin_or_owner_required, + setup_required, + with_current_tenant_id, + with_current_user_id, +) from core.plugin.impl.exc import PluginPermissionDeniedError from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required from services.plugin.endpoint_service import EndpointService @@ -96,17 +102,15 @@ register_schema_models( ) -def _create_endpoint() -> dict[str, bool]: - """Create a plugin endpoint for the current workspace.""" - user, tenant_id = current_account_with_tenant() - +def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]: + """Create a plugin endpoint for the injected workspace and user.""" args = EndpointCreatePayload.model_validate(console_ns.payload) try: return { "success": EndpointService.create_endpoint( tenant_id=tenant_id, - user_id=user.id, + user_id=user_id, plugin_unique_identifier=args.plugin_unique_identifier, name=args.name, settings=args.settings, @@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]: raise ValueError(e.description) from e -def _update_endpoint(endpoint_id: str) -> dict[str, bool]: +def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]: """Update a plugin endpoint identified by the canonical path parameter.""" - user, tenant_id = current_account_with_tenant() - args = EndpointUpdatePayload.model_validate(console_ns.payload) return { "success": EndpointService.update_endpoint( tenant_id=tenant_id, - user_id=user.id, + user_id=user_id, endpoint_id=endpoint_id, name=args.name, settings=args.settings, @@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]: } -def _delete_endpoint(endpoint_id: str) -> dict[str, bool]: +def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]: """Delete a plugin endpoint identified by the canonical path parameter.""" - user, tenant_id = current_account_with_tenant() - return { "success": EndpointService.delete_endpoint( tenant_id=tenant_id, - user_id=user.id, + user_id=user_id, endpoint_id=endpoint_id, ) } @@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - return _create_endpoint() + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, user_id: str): + return _create_endpoint(tenant_id=tenant_id, user_id=user_id) @console_ns.route("/workspaces/current/endpoints/create") @@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - return _create_endpoint() + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, user_id: str): + return _create_endpoint(tenant_id=tenant_id, user_id=user_id) @console_ns.route("/workspaces/current/endpoints/list") @@ -206,9 +210,9 @@ class EndpointListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - + @with_current_user_id + @with_current_tenant_id + def get(self, tenant_id: str, user_id: str): args = EndpointListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page @@ -218,7 +222,7 @@ class EndpointListApi(Resource): { "endpoints": EndpointService.list_endpoints( tenant_id=tenant_id, - user_id=user.id, + user_id=user_id, page=page, page_size=page_size, ) @@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - user, tenant_id = current_account_with_tenant() - + @with_current_user_id + @with_current_tenant_id + def get(self, tenant_id: str, user_id: str): args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True)) page = args.page @@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource): { "endpoints": EndpointService.list_endpoints_for_single_plugin( tenant_id=tenant_id, - user_id=user.id, + user_id=user_id, plugin_id=plugin_id, page=page, page_size=page_size, @@ -278,8 +282,10 @@ class EndpointItemApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, id: str): - return _delete_endpoint(endpoint_id=id) + @with_current_user_id + @with_current_tenant_id + def delete(self, tenant_id: str, user_id: str, id: str): + return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id) @console_ns.doc("update_endpoint") @console_ns.doc(description="Update a plugin endpoint") @@ -295,8 +301,10 @@ class EndpointItemApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def patch(self, id: str): - return _update_endpoint(endpoint_id=id) + @with_current_user_id + @with_current_tenant_id + def patch(self, tenant_id: str, user_id: str, id: str): + return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id) @console_ns.route("/workspaces/current/endpoints/delete") @@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, user_id: str): args = EndpointIdPayload.model_validate(console_ns.payload) - return _delete_endpoint(endpoint_id=args.endpoint_id) + return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id) @console_ns.route("/workspaces/current/endpoints/update") @@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, user_id: str): args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload) - return _update_endpoint(endpoint_id=args.endpoint_id) + return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id) @console_ns.route("/workspaces/current/endpoints/enable") @@ -370,14 +382,14 @@ class EndpointEnableApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, user_id: str): args = EndpointIdPayload.model_validate(console_ns.payload) return { "success": EndpointService.enable_endpoint( - tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id + tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id ) } @@ -397,13 +409,13 @@ class EndpointDisableApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self): - user, tenant_id = current_account_with_tenant() - + @with_current_user_id + @with_current_tenant_id + def post(self, tenant_id: str, user_id: str): args = EndpointIdPayload.model_validate(console_ns.payload) return { "success": EndpointService.disable_endpoint( - tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id + tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id ) } diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 2cb1aeaaf8..60ecaa16bd 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -25,13 +25,15 @@ from controllers.console.wraps import ( cloud_edition_billing_resource_check, only_edition_enterprise, setup_required, + with_current_tenant_id, + with_current_user, ) from enums.cloud_plan import CloudPlan from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import TimestampField, dump_response, to_timestamp -from libs.login import current_account_with_tenant, login_required -from models.account import Tenant, TenantCustomConfigDict, TenantStatus +from libs.login import login_required +from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus from services.account_service import TenantService from services.billing_service import BillingService, SubscriptionPlan from services.enterprise.enterprise_service import EnterpriseService @@ -153,8 +155,9 @@ class TenantListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): - current_user, current_tenant_id = current_account_with_tenant() + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED @@ -228,11 +231,11 @@ class TenantApi(Resource): @login_required @account_initialization_required @console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__]) - def post(self): + @with_current_user + def post(self, current_user: Account): if request.path == "/info": logger.warning("Deprecated URL /info was used.") - current_user, _ = current_account_with_tenant() tenant = current_user.current_tenant if not tenant: raise ValueError("No current tenant") @@ -256,8 +259,8 @@ class SwitchWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): payload = console_ns.payload or {} args = SwitchWorkspacePayload.model_validate(payload) @@ -281,8 +284,8 @@ class CustomConfigWorkspaceApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") - def post(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str): payload = console_ns.payload or {} args = WorkspaceCustomConfigPayload.model_validate(payload) tenant = db.get_or_404(Tenant, current_tenant_id) @@ -308,8 +311,8 @@ class WebappLogoWorkspaceApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("workspace_custom") - def post(self): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account): # check file if "file" not in request.files: raise NoFileUploadedError() @@ -349,8 +352,8 @@ class WorkspaceInfoApi(Resource): @login_required @account_initialization_required # Change workspace name - def post(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def post(self, current_tenant_id: str): payload = console_ns.payload or {} args = WorkspaceInfoPayload.model_validate(payload) @@ -372,13 +375,12 @@ class WorkspacePermissionApi(Resource): @login_required @account_initialization_required @only_edition_enterprise - def get(self): + @with_current_tenant_id + def get(self, current_tenant_id: str): """ Get workspace permission settings. Returns permission flags that control workspace features like member invitations and owner transfer. """ - _, current_tenant_id = current_account_with_tenant() - if not current_tenant_id: raise ValueError("No current tenant") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index ed67c24bd3..1c0f210123 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -4,7 +4,7 @@ import os import time from collections.abc import Callable from functools import wraps -from typing import Concatenate +from typing import Any, Concatenate, overload from flask import abort, request from pydantic import BaseModel, ValidationError @@ -37,9 +37,21 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data" ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code" -def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: +@overload +def account_initialization_required[T, **P, R]( + view: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R]: ... + + +@overload +def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ... + + +def account_initialization_required[R](view: Callable[..., R]) -> Callable[..., R]: @wraps(view) - def decorated(*args: P.args, **kwargs: P.kwargs) -> R: + def decorated(*args: Any, **kwargs: Any) -> R: + # The overloads keep Resource methods method-aware for pyrefly while + # preserving support for plain functions used in tests and utilities. # check account initialization current_user, _ = current_account_with_tenant() if current_user.status == AccountStatus.UNINITIALIZED: @@ -218,9 +230,21 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]: return decorated -def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: +@overload +def setup_required[T, **P, R]( + view: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R]: ... + + +@overload +def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ... + + +def setup_required[R](view: Callable[..., R]) -> Callable[..., R]: @wraps(view) - def decorated(*args: P.args, **kwargs: P.kwargs) -> R: + def decorated(*args: Any, **kwargs: Any) -> R: + # The overloads keep Resource methods method-aware for pyrefly while + # preserving support for plain functions used in tests and utilities. # check setup if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)): if os.environ.get("INIT_PASSWORD"): @@ -552,7 +576,7 @@ def with_current_user_id[T, **P, R]( @wraps(view) def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R: current_user, _ = current_account_with_tenant() - return view(self, str(current_user.id), *args, **kwargs) + return view(self, current_user.id, *args, **kwargs) return decorated diff --git a/api/libs/login.py b/api/libs/login.py index 067597cb3c..12d0f53f2d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Concatenate, cast, overload from flask import Response, current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS @@ -48,7 +48,17 @@ def current_account_with_tenant() -> tuple[Account, str]: return user, user.current_tenant_id -def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: +@overload +def login_required[T, **P, R]( + func: Callable[Concatenate[T, P], R], +) -> Callable[Concatenate[T, P], R | Response]: ... + + +@overload +def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: ... + + +def login_required[R](func: Callable[..., R]) -> Callable[..., R | Response]: """ If you decorate a view with this, it will ensure that the current user is logged in and authenticated before calling the actual view. (If they are @@ -83,7 +93,9 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: """ @wraps(func) - def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response: + def decorated_view(*args: Any, **kwargs: Any) -> R | Response: + # The overloads keep Resource methods method-aware for pyrefly while + # preserving support for plain Flask view functions. if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 54b82df4e5..c644281190 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -2,6 +2,8 @@ from __future__ import annotations +import inspect +from collections.abc import Iterator from datetime import UTC, datetime from unittest.mock import MagicMock, PropertyMock, patch @@ -19,31 +21,18 @@ from controllers.console.datasets.data_source import ( DataSourceNotionPreviewApi, ) from core.rag.index_processor.constant.index_type import IndexStructureType -from models import DataSourceOauthBinding - - -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +from models import Account, DataSourceOauthBinding @pytest.fixture -def tenant_ctx(): - return (MagicMock(id="u1"), "tenant-1") +def current_user() -> Account: + account = Account(name="Test User", email="u1@example.com") + account.id = "u1" + return account @pytest.fixture -def patch_tenant(tenant_ctx): - with patch( - "controllers.console.datasets.data_source.current_account_with_tenant", - return_value=tenant_ctx, - ): - yield - - -@pytest.fixture -def mock_engine(): +def mock_engine() -> Iterator[None]: with patch.object( type(data_source.db), "engine", @@ -55,12 +44,12 @@ def mock_engine(): class TestDataSourceApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_success(self, app: Flask, patch_tenant): + def test_get_success(self, app: Flask) -> None: api = DataSourceApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) binding = DataSourceOauthBinding( tenant_id="tenant-1", @@ -93,7 +82,7 @@ class TestDataSourceApi: return_value=MagicMock(all=lambda: [binding]), ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 assert response["data"][0] == { @@ -120,9 +109,9 @@ class TestDataSourceApi: "link": "http://localhost/console/api/oauth/data-source/notion", } - def test_get_no_bindings(self, app: Flask, patch_tenant): + def test_get_no_bindings(self, app: Flask) -> None: api = DataSourceApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -131,14 +120,14 @@ class TestDataSourceApi: return_value=MagicMock(all=lambda: []), ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 assert response["data"] == [] - def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine): + def test_patch_enable_binding(self, app: Flask, mock_engine: None) -> None: api = DataSourceApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) binding = MagicMock(id="b1", disabled=True) @@ -152,14 +141,14 @@ class TestDataSourceApi: mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = binding - response, status = method(api, "b1", "enable") + response, status = method(api, "tenant-1", "b1", "enable") assert status == 200 assert binding.disabled is False - def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine): + def test_patch_disable_binding(self, app: Flask, mock_engine: None) -> None: api = DataSourceApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) binding = MagicMock(id="b1", disabled=False) @@ -173,14 +162,14 @@ class TestDataSourceApi: mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.execute.return_value.scalar_one_or_none.return_value = binding - response, status = method(api, "b1", "disable") + response, status = method(api, "tenant-1", "b1", "disable") assert status == 200 assert binding.disabled is True - def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine): + def test_patch_binding_not_found(self, app: Flask, mock_engine: None) -> None: api = DataSourceApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) with ( app.test_request_context("/"), @@ -191,11 +180,11 @@ class TestDataSourceApi: mock_session.execute.return_value.scalar_one_or_none.return_value = None with pytest.raises(NotFound): - method(api, "b1", "enable") + method(api, "tenant-1", "b1", "enable") - def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine): + def test_patch_enable_already_enabled(self, app: Flask, mock_engine: None) -> None: api = DataSourceApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) binding = MagicMock(id="b1", disabled=False) @@ -208,11 +197,11 @@ class TestDataSourceApi: mock_session.execute.return_value.scalar_one_or_none.return_value = binding with pytest.raises(ValueError): - method(api, "b1", "enable") + method(api, "tenant-1", "b1", "enable") - def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine): + def test_patch_disable_already_disabled(self, app: Flask, mock_engine: None) -> None: api = DataSourceApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) binding = MagicMock(id="b1", disabled=True) @@ -225,17 +214,17 @@ class TestDataSourceApi: mock_session.execute.return_value.scalar_one_or_none.return_value = binding with pytest.raises(ValueError): - method(api, "b1", "disable") + method(api, "tenant-1", "b1", "disable") class TestDataSourceNotionListApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_credential_not_found(self, app: Flask, patch_tenant): + def test_get_credential_not_found(self, app: Flask, current_user: Account) -> None: api = DataSourceNotionListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?credential_id=c1"), @@ -245,11 +234,11 @@ class TestDataSourceNotionListApi: ), ): with pytest.raises(NotFound): - method(api) + method(api, "tenant-1", current_user) - def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine): + def test_get_success_no_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None: api = DataSourceNotionListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) page = MagicMock( page_id="p1", @@ -284,13 +273,13 @@ class TestDataSourceNotionListApi: ), ), ): - response, status = method(api) + response, status = method(api, "tenant-1", current_user) assert status == 200 - def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine): + def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None: api = DataSourceNotionListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) page = MagicMock( page_id="p1", @@ -337,13 +326,13 @@ class TestDataSourceNotionListApi: mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session mock_session.scalars.return_value.all.return_value = [document] - response, status = method(api) + response, status = method(api, "tenant-1", current_user) assert status == 200 - def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine): + def test_get_invalid_dataset_type(self, app: Flask, current_user: Account, mock_engine: None) -> None: api = DataSourceNotionListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) dataset = MagicMock(data_source_type="other_type") @@ -360,17 +349,17 @@ class TestDataSourceNotionListApi: patch("controllers.console.datasets.data_source.sessionmaker"), ): with pytest.raises(ValueError): - method(api) + method(api, "tenant-1", current_user) class TestDataSourceNotionPreviewApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_preview_success(self, app: Flask, patch_tenant): + def test_get_preview_success(self, app: Flask) -> None: api = DataSourceNotionPreviewApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")]) @@ -385,21 +374,22 @@ class TestDataSourceNotionPreviewApi: return_value=extractor, ), ): - response, status = method(api, "p1", "page") + response, status = method(api, "tenant-1", "p1", "page") assert status == 200 class TestDataSourceNotionIndexingEstimateApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_post_indexing_estimate_success(self, app: Flask, patch_tenant): + def test_post_indexing_estimate_success(self, app: Flask) -> None: api = DataSourceNotionIndexingEstimateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - payload = { + empty_rules: dict[str, object] = {} + payload: dict[str, object] = { "notion_info_list": [ { "workspace_id": "w1", @@ -407,7 +397,7 @@ class TestDataSourceNotionIndexingEstimateApi: "pages": [{"page_id": "p1", "type": "page"}], } ], - "process_rule": {"rules": {}}, + "process_rule": {"rules": empty_rules}, "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", } @@ -422,19 +412,19 @@ class TestDataSourceNotionIndexingEstimateApi: return_value=MagicMock(model_dump=lambda: {"total_pages": 1}), ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 class TestDataSourceNotionDatasetSyncApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_success(self, app: Flask, patch_tenant): + def test_get_success(self, app: Flask) -> None: api = DataSourceNotionDatasetSyncApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -455,9 +445,9 @@ class TestDataSourceNotionDatasetSyncApi: assert status == 200 - def test_get_dataset_not_found(self, app: Flask, patch_tenant): + def test_get_dataset_not_found(self, app: Flask) -> None: api = DataSourceNotionDatasetSyncApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -472,12 +462,12 @@ class TestDataSourceNotionDatasetSyncApi: class TestDataSourceNotionDocumentSyncApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_success(self, app: Flask, patch_tenant): + def test_get_success(self, app: Flask) -> None: api = DataSourceNotionDocumentSyncApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -498,9 +488,9 @@ class TestDataSourceNotionDocumentSyncApi: assert status == 200 - def test_get_document_not_found(self, app: Flask, patch_tenant): + def test_get_document_not_found(self, app: Flask) -> None: api = DataSourceNotionDocumentSyncApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9c5b5ec256..e8faece89c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,3 +1,4 @@ +import inspect from unittest.mock import MagicMock, patch import pytest @@ -23,25 +24,15 @@ from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - class TestDatasourcePluginOAuthAuthorizationUrl: def test_get_success(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() - method = unwrap(api.get) + method = inspect.unwrap(api.get) user = MagicMock(id="user-1") with ( app.test_request_context("/?credential_id=cred-1"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch.object( DatasourceProviderService, "get_oauth_client", @@ -58,20 +49,17 @@ class TestDatasourcePluginOAuthAuthorizationUrl: return_value={"url": "http://auth"}, ), ): - response = method(api, "notion") + response = method(api, "tenant-1", user, "notion") assert response.status_code == 200 def test_get_no_oauth_config(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock(id="user-1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "get_oauth_client", @@ -79,20 +67,16 @@ class TestDatasourcePluginOAuthAuthorizationUrl: ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", user, "notion") def test_get_without_credential_id_sets_cookie(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() - method = unwrap(api.get) + method = inspect.unwrap(api.get) user = MagicMock(id="user-1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch.object( DatasourceProviderService, "get_oauth_client", @@ -109,7 +93,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: return_value={"url": "http://auth"}, ), ): - response = method(api, "notion") + response = method(api, "tenant-1", user, "notion") assert response.status_code == 200 assert "context_id" in response.headers.get("Set-Cookie") @@ -118,7 +102,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: class TestDatasourceOAuthCallback: def test_callback_success_new_credential(self, app: Flask): api = DatasourceOAuthCallback() - method = unwrap(api.get) + method = inspect.unwrap(api.get) oauth_response = MagicMock() oauth_response.credentials = {"token": "abc"} @@ -160,7 +144,7 @@ class TestDatasourceOAuthCallback: def test_callback_missing_context(self, app: Flask): api = DatasourceOAuthCallback() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with app.test_request_context("/"): with pytest.raises(Forbidden): @@ -168,7 +152,7 @@ class TestDatasourceOAuthCallback: def test_callback_invalid_context(self, app: Flask): api = DatasourceOAuthCallback() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?context_id=bad"), @@ -183,7 +167,7 @@ class TestDatasourceOAuthCallback: def test_callback_oauth_config_not_found(self, app: Flask): api = DatasourceOAuthCallback() - method = unwrap(api.get) + method = inspect.unwrap(api.get) context = {"user_id": "u", "tenant_id": "t"} @@ -205,7 +189,7 @@ class TestDatasourceOAuthCallback: def test_callback_reauthorize_existing_credential(self, app: Flask): api = DatasourceOAuthCallback() - method = unwrap(api.get) + method = inspect.unwrap(api.get) oauth_response = MagicMock() oauth_response.credentials = {"token": "abc"} @@ -248,7 +232,7 @@ class TestDatasourceOAuthCallback: def test_callback_context_id_from_cookie(self, app: Flask): api = DatasourceOAuthCallback() - method = unwrap(api.get) + method = inspect.unwrap(api.get) oauth_response = MagicMock() oauth_response.credentials = {"token": "abc"} @@ -292,40 +276,32 @@ class TestDatasourceOAuthCallback: class TestDatasourceAuth: def test_post_success(self, app: Flask): api = DatasourceAuth() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credentials": {"key": "val"}} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "add_datasource_api_key_provider", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 200 def test_post_invalid_credentials(self, app: Flask): api = DatasourceAuth() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credentials": {"key": "bad"}} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "add_datasource_api_key_provider", @@ -333,63 +309,53 @@ class TestDatasourceAuth: ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", "notion") def test_get_success(self, app: Flask): api = DatasourceAuth() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock(id="user-1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "list_datasource_credentials", return_value=[{"id": "1"}], ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", user, "notion") assert status == 200 assert response["result"] def test_post_missing_credentials(self, app: Flask): api = DatasourceAuth() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", "notion") def test_get_empty_list(self, app: Flask): api = DatasourceAuth() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock(id="user-1") with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "list_datasource_credentials", return_value=[], ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", user, "notion") assert status == 200 assert response["result"] == [] @@ -398,136 +364,112 @@ class TestDatasourceAuth: class TestDatasourceAuthDeleteApi: def test_delete_success(self, app: Flask): api = DatasourceAuthDeleteApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credential_id": "cred-1"} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "remove_datasource_credentials", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 200 def test_delete_missing_credential_id(self, app: Flask): api = DatasourceAuthDeleteApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", "notion") class TestDatasourceAuthUpdateApi: def test_update_success(self, app: Flask): api = DatasourceAuthUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credential_id": "id", "credentials": {"k": "v"}} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "update_datasource_credentials", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 201 def test_update_with_credentials_none(self, app: Flask): api = DatasourceAuthUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credential_id": "id", "credentials": None} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "update_datasource_credentials", return_value=None, ) as update_mock, ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") update_mock.assert_called_once() assert status == 201 def test_update_name_only(self, app: Flask): api = DatasourceAuthUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credential_id": "id", "name": "New Name"} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "update_datasource_credentials", return_value=None, ), ): - _, status = method(api, "notion") + _, status = method(api, "tenant-1", "notion") assert status == 201 def test_update_with_empty_credentials_dict(self, app: Flask): api = DatasourceAuthUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credential_id": "id", "credentials": {}} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "update_datasource_credentials", return_value=None, ) as update_mock, ): - _, status = method(api, "notion") + _, status = method(api, "tenant-1", "notion") update_mock.assert_called_once() assert status == 201 @@ -536,62 +478,50 @@ class TestDatasourceAuthUpdateApi: class TestDatasourceAuthListApi: def test_list_success(self, app: Flask): api = DatasourceAuthListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "get_all_datasource_credentials", return_value=[{"id": "1"}], ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 def test_auth_list_empty(self, app: Flask): api = DatasourceAuthListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "get_all_datasource_credentials", return_value=[], ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 assert response["result"] == [] def test_hardcode_list_empty(self, app: Flask): api = DatasourceHardCodeAuthListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "get_hard_code_datasource_credentials", return_value=[], ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 assert response["result"] == [] @@ -600,21 +530,17 @@ class TestDatasourceAuthListApi: class TestDatasourceHardCodeAuthListApi: def test_list_success(self, app: Flask): api = DatasourceHardCodeAuthListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "get_hard_code_datasource_credentials", return_value=[{"id": "1"}], ), ): - response, status = method(api) + response, status = method(api, "tenant-1") assert status == 200 @@ -622,73 +548,61 @@ class TestDatasourceHardCodeAuthListApi: class TestDatasourceAuthOauthCustomClient: def test_post_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"client_params": {}, "enable_oauth_custom_client": True} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "setup_oauth_custom_client_params", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 200 def test_delete_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "remove_oauth_custom_client_params", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 200 def test_post_empty_payload(self, app: Flask): api = DatasourceAuthOauthCustomClient() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "setup_oauth_custom_client_params", return_value=None, ), ): - _, status = method(api, "notion") + _, status = method(api, "tenant-1", "notion") assert status == 200 def test_post_disabled_flag(self, app: Flask): api = DatasourceAuthOauthCustomClient() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "client_params": {"a": 1}, @@ -698,17 +612,13 @@ class TestDatasourceAuthOauthCustomClient: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "setup_oauth_custom_client_params", return_value=None, ) as setup_mock, ): - _, status = method(api, "notion") + _, status = method(api, "tenant-1", "notion") setup_mock.assert_called_once() assert status == 200 @@ -717,72 +627,60 @@ class TestDatasourceAuthOauthCustomClient: class TestDatasourceAuthDefaultApi: def test_set_default_success(self, app: Flask): api = DatasourceAuthDefaultApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"id": "cred-1"} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "set_default_datasource_provider", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 200 def test_default_missing_id(self, app: Flask): api = DatasourceAuthDefaultApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", "notion") class TestDatasourceUpdateProviderNameApi: def test_update_name_success(self, app: Flask): api = DatasourceUpdateProviderNameApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"credential_id": "id", "name": "New Name"} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch.object( DatasourceProviderService, "update_datasource_provider_name", return_value=None, ), ): - response, status = method(api, "notion") + response, status = method(api, "tenant-1", "notion") assert status == 200 def test_update_name_too_long(self, app: Flask): api = DatasourceUpdateProviderNameApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "credential_id": "id", @@ -792,27 +690,19 @@ class TestDatasourceUpdateProviderNameApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", "notion") def test_update_name_missing_credential_id(self, app: Flask): api = DatasourceUpdateProviderNameApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"name": "Valid"} with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), ): with pytest.raises(ValueError): - method(api, "notion") + method(api, "tenant-1", "notion") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py index 5890175c97..b608e8a73f 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -11,7 +11,7 @@ from flask import Flask from controllers.console.datasets import data_source as module from controllers.console.datasets.data_source import DataSourceApi, DataSourceNotionListApi -from models import DataSourceOauthBinding +from models import Account, DataSourceOauthBinding ControllerMethod = Callable[..., tuple[dict[str, object], int]] @@ -28,13 +28,13 @@ def flask_app() -> Flask: @pytest.fixture -def tenant_context() -> tuple[MagicMock, str]: - return MagicMock(id="user-1"), "tenant-1" +def current_user() -> Account: + account = Account(name="Test User", email="user-1@example.com") + account.id = "user-1" + return account -def test_get_data_source_integrates_serializes_orm_binding( - flask_app: Flask, tenant_context: tuple[MagicMock, str] -) -> None: +def test_get_data_source_integrates_serializes_orm_binding(flask_app: Flask) -> None: binding = DataSourceOauthBinding( tenant_id="tenant-1", access_token="token", @@ -61,10 +61,9 @@ def test_get_data_source_integrates_serializes_orm_binding( with ( flask_app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=tenant_context), patch.object(module.db.session, "scalars", return_value=MagicMock(all=lambda: [binding])), ): - response, status = unwrap(DataSourceApi().get)(DataSourceApi()) + response, status = unwrap(DataSourceApi().get)(DataSourceApi(), "tenant-1") assert status == 200 assert response == { @@ -96,23 +95,18 @@ def test_get_data_source_integrates_serializes_orm_binding( } -def test_get_data_source_integrates_preserves_empty_list_when_no_binding( - flask_app: Flask, tenant_context: tuple[MagicMock, str] -) -> None: +def test_get_data_source_integrates_preserves_empty_list_when_no_binding(flask_app: Flask) -> None: with ( flask_app.test_request_context("/"), - patch.object(module, "current_account_with_tenant", return_value=tenant_context), patch.object(module.db.session, "scalars", return_value=MagicMock(all=lambda: [])), ): - response, status = unwrap(DataSourceApi().get)(DataSourceApi()) + response, status = unwrap(DataSourceApi().get)(DataSourceApi(), "tenant-1") assert status == 200 assert response == {"data": []} -def test_notion_pre_import_pages_serializes_frontend_list_shape( - flask_app: Flask, tenant_context: tuple[MagicMock, str] -) -> None: +def test_notion_pre_import_pages_serializes_frontend_list_shape(flask_app: Flask, current_user: Account) -> None: page = MagicMock( page_id="page-1", page_name="Page", @@ -137,7 +131,6 @@ def test_notion_pre_import_pages_serializes_frontend_list_shape( with ( flask_app.test_request_context("/?credential_id=credential-1"), - patch.object(module, "current_account_with_tenant", return_value=tenant_context), patch.object( module.DatasourceProviderService, "get_datasource_credentials", @@ -147,7 +140,7 @@ def test_notion_pre_import_pages_serializes_frontend_list_shape( patch.object(module, "sessionmaker"), patch("core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", return_value=runtime), ): - response, status = unwrap(DataSourceNotionListApi().get)(DataSourceNotionListApi()) + response, status = unwrap(DataSourceNotionListApi().get)(DataSourceNotionListApi(), "tenant-1", current_user) assert status == 200 assert response == { diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index 16ead7c44f..0b4ce39baf 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,3 +1,4 @@ +import inspect from unittest.mock import MagicMock, patch import pytest @@ -39,12 +40,6 @@ from models.dataset import Document as DatasetDocument from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def make_serializable_document(**overrides): attrs = { "id": "doc-1", @@ -125,11 +120,7 @@ def tenant_ctx(): @pytest.fixture def patch_tenant(tenant_ctx): - with patch( - "controllers.console.datasets.datasets_document.current_account_with_tenant", - return_value=tenant_ctx, - ): - yield + return tenant_ctx @pytest.fixture @@ -173,16 +164,18 @@ def patch_permission(): class TestGetProcessRuleApi: def test_get_default_success(self, app: Flask, patch_tenant): api = GetProcessRuleApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, _ = patch_tenant with app.test_request_context("/"): - response = method(api) + response = method(api, user) assert "rules" in response def test_get_with_document_dataset_not_found(self, app: Flask, patch_tenant): api = GetProcessRuleApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, _ = patch_tenant document = MagicMock(dataset_id="ds-1") @@ -198,13 +191,14 @@ class TestGetProcessRuleApi: ), ): with pytest.raises(NotFound): - method(api) + method(api, user) class TestDatasetDocumentListApi: def test_get_with_fetch_true_counts_segments(self, app: Flask, patch_tenant, patch_dataset, patch_permission): api = DatasetDocumentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant doc = make_serializable_document() pagination = MagicMock(items=[doc], total=1) @@ -224,7 +218,7 @@ class TestDatasetDocumentListApi: return_value=None, ), ): - resp = method(api, "ds-1") + resp = method(api, tenant_id, user, "ds-1") assert resp["data"][0]["id"] == "doc-1" assert resp["data"][0]["completed_segments"] == 2 @@ -234,7 +228,8 @@ class TestDatasetDocumentListApi: self, app: Flask, patch_tenant, patch_dataset, patch_permission ): api = DatasetDocumentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant pagination = MagicMock(items=[make_serializable_document()], total=1) @@ -253,13 +248,14 @@ class TestDatasetDocumentListApi: return_value=None, ), ): - resp = method(api, "ds-1") + resp = method(api, tenant_id, user, "ds-1") assert resp["total"] == 1 def test_get_success(self, app: Flask, patch_tenant, patch_dataset, patch_permission): api = DatasetDocumentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant pagination = MagicMock(items=[make_serializable_document()], total=1) @@ -274,7 +270,7 @@ class TestDatasetDocumentListApi: return_value=None, ), ): - response = method(api, "ds-1") + response = method(api, tenant_id, user, "ds-1") assert response["total"] == 1 assert response["data"][0]["id"] == "doc-1" @@ -283,7 +279,8 @@ class TestDatasetDocumentListApi: def test_post_success(self, app: Flask, patch_tenant, patch_dataset, patch_permission): api = DatasetDocumentListApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, _ = patch_tenant payload = {"indexing_technique": "economy"} created_dataset = make_dataset() @@ -306,7 +303,7 @@ class TestDatasetDocumentListApi: ), patch("models.dataset.db.session.scalar", return_value=0), ): - response = method(api, "ds-1") + response = method(api, user, "ds-1") assert "documents" in response assert response["dataset"]["id"] == "ds-1" @@ -318,28 +315,25 @@ class TestDatasetDocumentListApi: def test_post_forbidden(self, app: Flask): api = DatasetDocumentListApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) user = MagicMock(is_dataset_editor=False) with ( app.test_request_context("/", json={}), patch.object(type(console_ns), "payload", {}), - patch( - "controllers.console.datasets.datasets_document.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_document.DatasetService.get_dataset", return_value=MagicMock(), ), ): with pytest.raises(Forbidden): - method(api, "ds-1") + method(api, user, "ds-1") def test_get_with_fetch_true_and_invalid_fetch(self, app: Flask, patch_tenant, patch_dataset, patch_permission): api = DatasetDocumentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant pagination = MagicMock(items=[make_serializable_document()], total=1) @@ -354,13 +348,14 @@ class TestDatasetDocumentListApi: return_value=None, ), ): - response = method(api, "ds-1") + response = method(api, tenant_id, user, "ds-1") assert response["total"] == 1 def test_get_sort_hit_count(self, app: Flask, patch_tenant, patch_dataset, patch_permission): api = DatasetDocumentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant pagination = MagicMock(items=[], total=0) @@ -375,7 +370,7 @@ class TestDatasetDocumentListApi: return_value=None, ), ): - response = method(api, "ds-1") + response = method(api, tenant_id, user, "ds-1") assert response["total"] == 0 @@ -383,7 +378,8 @@ class TestDatasetDocumentListApi: class TestDatasetInitApi: def test_post_success_serializes_created_dataset_and_documents(self, app: Flask, patch_tenant): api = DatasetInitApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, tenant_id = patch_tenant payload = {"indexing_technique": "economy"} created_dataset = make_dataset() @@ -402,7 +398,7 @@ class TestDatasetInitApi: ), patch("models.dataset.db.session.scalar", return_value=0), ): - response = method(api) + response = method(api, tenant_id, user) assert response["dataset"]["id"] == "ds-1" assert response["documents"][0]["id"] == "doc-init" @@ -414,7 +410,8 @@ class TestDatasetInitApi: class TestDocumentApi: def test_get_success(self, app: Flask, patch_tenant): api = DocumentApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock(dataset_process_rule=None) @@ -426,21 +423,23 @@ class TestDocumentApi: return_value={}, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 200 def test_get_invalid_metadata(self, app: Flask, patch_tenant): api = DocumentApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant with app.test_request_context("/?metadata=wrong"), patch.object(api, "get_document", return_value=MagicMock()): with pytest.raises(InvalidMetadataError): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") def test_delete_success(self, app: Flask, patch_tenant, patch_dataset): api = DocumentApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) + user, tenant_id = patch_tenant with ( app.test_request_context("/"), @@ -454,13 +453,14 @@ class TestDocumentApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 204 def test_delete_indexing_error(self, app: Flask, patch_tenant, patch_dataset): api = DocumentApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) + user, tenant_id = patch_tenant with ( app.test_request_context("/"), @@ -475,13 +475,14 @@ class TestDocumentApi: ), ): with pytest.raises(DocumentIndexingError): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") class TestDocumentDownloadApi: def test_download_success(self, app: Flask, patch_tenant): api = DocumentDownloadApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock() @@ -493,7 +494,7 @@ class TestDocumentDownloadApi: return_value="url", ), ): - response = method(api, "ds-1", "doc-1") + response = method(api, tenant_id, user, "ds-1", "doc-1") assert response["url"] == "url" @@ -501,24 +502,21 @@ class TestDocumentDownloadApi: class TestDocumentProcessingApi: def test_processing_forbidden_when_not_editor(self, app: Flask): api = DocumentProcessingApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) user = MagicMock(is_dataset_editor=False) with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_document.current_account_with_tenant", - return_value=(user, "tenant"), - ), patch.object(api, "get_document", return_value=MagicMock()), ): with pytest.raises(Forbidden): - method(api, "ds-1", "doc-1", "pause") + method(api, "tenant-1", user, "ds-1", "doc-1", "pause") def test_resume_from_error_state(self, app: Flask, patch_tenant): api = DocumentProcessingApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, tenant_id = patch_tenant doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True) @@ -530,13 +528,14 @@ class TestDocumentProcessingApi: return_value=None, ), ): - _, status = method(api, "ds-1", "doc-1", "resume") + _, status = method(api, tenant_id, user, "ds-1", "doc-1", "resume") assert status == 200 def test_resume_success(self, app: Flask, patch_tenant): api = DocumentProcessingApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, tenant_id = patch_tenant document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True) @@ -548,13 +547,14 @@ class TestDocumentProcessingApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1", "resume") + response, status = method(api, tenant_id, user, "ds-1", "doc-1", "resume") assert status == 200 def test_pause_success(self, app: Flask, patch_tenant): api = DocumentProcessingApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, tenant_id = patch_tenant document = MagicMock(indexing_status="indexing") @@ -566,25 +566,27 @@ class TestDocumentProcessingApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1", "pause") + response, status = method(api, tenant_id, user, "ds-1", "doc-1", "pause") assert status == 200 def test_pause_invalid(self, app: Flask, patch_tenant): api = DocumentProcessingApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, tenant_id = patch_tenant document = MagicMock(indexing_status=IndexingStatus.COMPLETED) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(InvalidActionError): - method(api, "ds-1", "doc-1", "pause") + method(api, tenant_id, user, "ds-1", "doc-1", "pause") class TestDocumentMetadataApi: def test_put_metadata_schema_filtering(self, app: Flask, patch_tenant): api = DocumentMetadataApi() - method = unwrap(api.put) + method = inspect.unwrap(api.put) + user, tenant_id = patch_tenant doc = MagicMock() @@ -607,13 +609,14 @@ class TestDocumentMetadataApi: return_value=None, ), ): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") assert doc.doc_metadata == {"amount": 10} def test_put_success(self, app: Flask, patch_tenant): api = DocumentMetadataApi() - method = unwrap(api.put) + method = inspect.unwrap(api.put) + user, tenant_id = patch_tenant document = MagicMock() @@ -631,21 +634,23 @@ class TestDocumentMetadataApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 200 def test_put_invalid_payload(self, app: Flask, patch_tenant): api = DocumentMetadataApi() - method = unwrap(api.put) + method = inspect.unwrap(api.put) + user, tenant_id = patch_tenant with app.test_request_context("/", json={}), patch.object(api, "get_document", return_value=MagicMock()): with pytest.raises(ValueError): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") def test_put_invalid_doc_type(self, app: Flask, patch_tenant): api = DocumentMetadataApi() - method = unwrap(api.put) + method = inspect.unwrap(api.put) + user, tenant_id = patch_tenant payload = {"doc_type": "invalid", "doc_metadata": {}} @@ -658,13 +663,14 @@ class TestDocumentMetadataApi: ), ): with pytest.raises(ValueError): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") class TestDocumentStatusApi: def test_patch_success(self, app: Flask, patch_tenant, patch_dataset): api = DocumentStatusApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, _ = patch_tenant with ( app.test_request_context("/?document_id=doc-1"), @@ -681,13 +687,14 @@ class TestDocumentStatusApi: return_value=None, ), ): - response, status = method(api, "ds-1", "enable") + response, status = method(api, user, "ds-1", "enable") assert status == 200 def test_patch_invalid_action(self, app: Flask, patch_tenant, patch_dataset): api = DocumentStatusApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, _ = patch_tenant with ( app.test_request_context("/?document_id=doc-1"), @@ -705,13 +712,13 @@ class TestDocumentStatusApi: ), ): with pytest.raises(InvalidActionError): - method(api, "ds-1", "enable") + method(api, user, "ds-1", "enable") class TestDocumentRetryApi: def test_retry_archived_document_skipped(self, app: Flask, patch_tenant, patch_dataset): api = DocumentRetryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"document_ids": ["doc-1"]} @@ -739,7 +746,7 @@ class TestDocumentRetryApi: def test_retry_success(self, app: Flask, patch_tenant, patch_dataset): api = DocumentRetryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"document_ids": ["doc-1"]} @@ -768,7 +775,7 @@ class TestDocumentRetryApi: def test_retry_skips_completed_document(self, app: Flask, patch_tenant, patch_dataset): api = DocumentRetryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"document_ids": ["doc-1"]} @@ -795,7 +802,7 @@ class TestDocumentRetryApi: class TestDocumentPipelineExecutionLogApi: def test_get_log_success(self, app: Flask, patch_tenant, patch_dataset): api = DocumentPipelineExecutionLogApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) log = MagicMock( datasource_info="{}", @@ -823,7 +830,8 @@ class TestDocumentPipelineExecutionLogApi: class TestDocumentGenerateSummaryApi: def test_generate_summary_missing_documents(self, app: Flask, patch_tenant, patch_permission): api = DocumentGenerateSummaryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, _ = patch_tenant dataset = MagicMock( indexing_technique="high_quality", @@ -845,11 +853,12 @@ class TestDocumentGenerateSummaryApi: ), ): with pytest.raises(NotFound): - method(api, "ds-1") + method(api, user, "ds-1") def test_generate_not_enabled(self, app: Flask, patch_tenant, patch_permission): api = DocumentGenerateSummaryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, _ = patch_tenant dataset = MagicMock(indexing_technique="high_quality", summary_index_setting={"enable": False}) @@ -864,11 +873,12 @@ class TestDocumentGenerateSummaryApi: ), ): with pytest.raises(ValueError): - method(api, "ds-1") + method(api, user, "ds-1") def test_generate_summary_success_with_qa_skip(self, app: Flask, patch_tenant, patch_permission): api = DocumentGenerateSummaryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, _ = patch_tenant dataset = MagicMock( indexing_technique="high_quality", @@ -896,7 +906,7 @@ class TestDocumentGenerateSummaryApi: return_value=None, ), ): - response, status = method(api, "ds-1") + response, status = method(api, user, "ds-1") assert status == 200 @@ -904,7 +914,8 @@ class TestDocumentGenerateSummaryApi: class TestDocumentSummaryStatusApi: def test_get_success(self, app: Flask, patch_tenant, patch_permission): api = DocumentSummaryStatusApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, _ = patch_tenant with ( app.test_request_context("/"), @@ -917,7 +928,7 @@ class TestDocumentSummaryStatusApi: return_value={"total_segments": 0}, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, user, "ds-1", "doc-1") assert status == 200 @@ -925,7 +936,8 @@ class TestDocumentSummaryStatusApi: class TestDocumentIndexingEstimateApi: def test_indexing_estimate_file_not_found(self, app: Flask, patch_tenant): api = DocumentIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock( indexing_status=IndexingStatus.INDEXING, @@ -945,11 +957,12 @@ class TestDocumentIndexingEstimateApi: ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") def test_indexing_estimate_generic_exception(self, app: Flask, patch_tenant): api = DocumentIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock( indexing_status=IndexingStatus.INDEXING, @@ -982,36 +995,38 @@ class TestDocumentIndexingEstimateApi: ), ): with pytest.raises(IndexingEstimateError): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") def test_get_finished(self, app: Flask, patch_tenant): api = DocumentIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock(indexing_status=IndexingStatus.COMPLETED) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(DocumentAlreadyFinishedError): - method(api, "ds-1", "doc-1") + method(api, tenant_id, user, "ds-1", "doc-1") class TestDocumentBatchDownloadZipApi: def test_post_no_documents(self, app: Flask, patch_tenant): api = DocumentBatchDownloadZipApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, tenant_id = patch_tenant payload: dict[str, list[str]] = {"document_ids": []} with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): with pytest.raises(ValueError): - method(api, "ds-1") + method(api, tenant_id, user, "ds-1") class TestDatasetDocumentListApiDelete: def test_delete_success(self, app: Flask, patch_tenant, patch_dataset): """Test successful deletion of documents""" api = DatasetDocumentListApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with ( app.test_request_context("/?document_id=doc-1&document_id=doc-2"), @@ -1031,7 +1046,7 @@ class TestDatasetDocumentListApiDelete: def test_delete_indexing_error(self, app: Flask, patch_tenant, patch_dataset): """Test deletion with indexing error""" api = DatasetDocumentListApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with ( app.test_request_context("/?document_id=doc-1"), @@ -1050,7 +1065,7 @@ class TestDatasetDocumentListApiDelete: def test_delete_dataset_not_found(self, app: Flask, patch_tenant): """Test deletion when dataset not found""" api = DatasetDocumentListApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with ( app.test_request_context("/?document_id=doc-1"), @@ -1066,7 +1081,8 @@ class TestDatasetDocumentListApiDelete: class TestDocumentBatchIndexingEstimateApi: def test_batch_indexing_estimate_website(self, app: Flask, patch_tenant): api = DocumentBatchIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant doc = MagicMock( indexing_status=IndexingStatus.INDEXING, @@ -1089,13 +1105,14 @@ class TestDocumentBatchIndexingEstimateApi: return_value=MagicMock(model_dump=lambda: {"tokens": 2}), ), ): - resp, status = method(api, "ds-1", "batch-1") + resp, status = method(api, tenant_id, user, "ds-1", "batch-1") assert status == 200 def test_batch_indexing_estimate_notion(self, app: Flask, patch_tenant): api = DocumentBatchIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant doc = MagicMock( indexing_status=IndexingStatus.INDEXING, @@ -1117,13 +1134,14 @@ class TestDocumentBatchIndexingEstimateApi: return_value=MagicMock(model_dump=lambda: {"tokens": 1}), ), ): - resp, status = method(api, "ds-1", "batch-1") + resp, status = method(api, tenant_id, user, "ds-1", "batch-1") assert status == 200 def test_batch_estimate_unsupported_datasource(self, app: Flask, patch_tenant): api = DocumentBatchIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock( indexing_status=IndexingStatus.INDEXING, @@ -1134,22 +1152,24 @@ class TestDocumentBatchIndexingEstimateApi: with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): with pytest.raises(ValueError): - method(api, "ds-1", "batch-1") + method(api, tenant_id, user, "ds-1", "batch-1") def test_get_batch_estimate_invalid_batch(self, app: Flask, patch_tenant): """Test batch estimation with invalid batch""" api = DocumentBatchIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): with pytest.raises(NotFound): - method(api, "ds-1", "invalid-batch") + method(api, tenant_id, user, "ds-1", "invalid-batch") class TestDocumentBatchIndexingStatusApi: def test_get_batch_status_success_serializes_status_shape(self, app: Flask, patch_tenant): api = DocumentBatchIndexingStatusApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, _ = patch_tenant document = MagicMock( id="doc-1", @@ -1173,7 +1193,7 @@ class TestDocumentBatchIndexingStatusApi: side_effect=[2, 3], ), ): - response = method(api, "ds-1", "batch-1") + response = method(api, user, "ds-1", "batch-1") assert response == { "data": [ @@ -1197,17 +1217,19 @@ class TestDocumentBatchIndexingStatusApi: def test_get_batch_status_invalid_batch(self, app: Flask, patch_tenant): """Test batch status with invalid batch""" api = DocumentBatchIndexingStatusApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, _ = patch_tenant with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): with pytest.raises(NotFound): - method(api, "ds-1", "invalid-batch") + method(api, user, "ds-1", "invalid-batch") class TestDocumentIndexingStatusApi: def test_get_status_success_serializes_status_shape(self, app: Flask, patch_tenant): api = DocumentIndexingStatusApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock( id="doc-1", @@ -1231,7 +1253,7 @@ class TestDocumentIndexingStatusApi: side_effect=[1, 4], ), ): - response = method(api, "ds-1", "doc-1") + response = method(api, tenant_id, user, "ds-1", "doc-1") assert response["id"] == "doc-1" assert response["indexing_status"] == "indexing" @@ -1241,17 +1263,19 @@ class TestDocumentIndexingStatusApi: def test_get_status_document_not_found(self, app: Flask, patch_tenant): """Test getting status for non-existent document""" api = DocumentIndexingStatusApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant with app.test_request_context("/"), patch.object(api, "get_document", side_effect=NotFound()): with pytest.raises(NotFound): - method(api, "ds-1", "invalid-doc") + method(api, tenant_id, user, "ds-1", "invalid-doc") class TestDocumentRenameApi: def test_post_success_serializes_document_shape(self, app: Flask, patch_tenant): api = DocumentRenameApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, _ = patch_tenant payload = {"name": "Renamed Document"} renamed_document = make_document(id="doc-renamed", name="Renamed Document") @@ -1273,7 +1297,7 @@ class TestDocumentRenameApi: ), patch("models.dataset.db.session.scalar", return_value=0), ): - response = method(api, "ds-1", "doc-1") + response = method(api, user, "ds-1", "doc-1") assert response["id"] == "doc-renamed" assert response["name"] == "Renamed Document" @@ -1286,7 +1310,8 @@ class TestDocumentApiMetadata: def test_get_with_only_option(self, app: Flask, patch_tenant): """Test get with 'only' metadata option""" api = DocumentApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock(dataset_process_rule=None, doc_metadata_details=[]) @@ -1298,14 +1323,15 @@ class TestDocumentApiMetadata: return_value={}, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 200 def test_get_with_without_option(self, app: Flask, patch_tenant): """Test get with 'without' metadata option""" api = DocumentApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock(dataset_process_rule=None) @@ -1317,7 +1343,7 @@ class TestDocumentApiMetadata: return_value={}, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 200 @@ -1326,7 +1352,8 @@ class TestDocumentGenerateSummaryApiSuccess: def test_generate_not_enabled_high_quality(self, app: Flask, patch_tenant, patch_permission): """Test summary generation on non-high-quality dataset""" api = DocumentGenerateSummaryApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user, _ = patch_tenant dataset = MagicMock(indexing_technique="economy", summary_index_setting={"enable": True}) @@ -1341,26 +1368,28 @@ class TestDocumentGenerateSummaryApiSuccess: ), ): with pytest.raises(ValueError): - method(api, "ds-1") + method(api, user, "ds-1") class TestDocumentProcessingApiResume: def test_resume_invalid_status(self, app: Flask, patch_tenant): """Test resume on non-paused document""" api = DocumentProcessingApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) + user, tenant_id = patch_tenant document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False) with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): with pytest.raises(InvalidActionError): - method(api, "ds-1", "doc-1", "resume") + method(api, tenant_id, user, "ds-1", "doc-1", "resume") class TestDocumentPermissionCases: def test_document_batch_get_permission_denied(self, app: Flask, patch_tenant): api = DocumentBatchIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant with ( app.test_request_context("/"), @@ -1374,11 +1403,12 @@ class TestDocumentPermissionCases: ), ): with pytest.raises(Forbidden): - method(api, "ds-1", "batch-1") + method(api, tenant_id, user, "ds-1", "batch-1") def test_document_batch_get_documents_not_found(self, app: Flask, patch_tenant): api = DocumentBatchIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant with ( app.test_request_context("/"), @@ -1392,7 +1422,7 @@ class TestDocumentPermissionCases: ), patch.object(api, "get_batch_documents", return_value=None), ): - response, status = method(api, "ds-1", "batch-1") + response, status = method(api, tenant_id, user, "ds-1", "batch-1") assert status == 200 assert response == { @@ -1405,7 +1435,7 @@ class TestDocumentPermissionCases: def test_document_tenant_mismatch(self, app: Flask): api = DocumentApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) user = MagicMock(is_dataset_editor=True) document = MagicMock( @@ -1415,13 +1445,9 @@ class TestDocumentPermissionCases: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_document.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_document.DatasetService.get_dataset", - return_value=MagicMock(), # ✅ prevents real DB call + return_value=MagicMock(), ), patch( "controllers.console.datasets.datasets_document.DocumentService.get_document", @@ -1433,11 +1459,12 @@ class TestDocumentPermissionCases: ), ): with pytest.raises(Forbidden): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_process_rule_get_by_document_success(self, app: Flask, patch_tenant): api = GetProcessRuleApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, _ = patch_tenant document = MagicMock(dataset_id="ds-1") process_rule = MagicMock(mode="custom", rules_dict={"a": 1}) @@ -1461,7 +1488,7 @@ class TestDocumentPermissionCases: return_value=process_rule, ), ): - result = method(api) + result = method(api, user) if isinstance(result, tuple): response, status = result @@ -1473,16 +1500,13 @@ class TestDocumentPermissionCases: def test_process_rule_permission_denied(self, app: Flask): api = GetProcessRuleApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock(is_dataset_editor=True) document = MagicMock(dataset_id="ds-1") with ( app.test_request_context("/?document_id=doc-1"), - patch( - "controllers.console.datasets.datasets_document.current_account_with_tenant", - return_value=(MagicMock(is_dataset_editor=True), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_document.db.get_or_404", return_value=document, @@ -1497,14 +1521,15 @@ class TestDocumentPermissionCases: ), ): with pytest.raises(Forbidden): - method(api) + method(api, user) class TestDocumentListAdvancedCases: def test_document_list_with_multiple_sort_options(self, app: Flask, patch_tenant, patch_dataset, patch_permission): """Test document list with different sort options""" api = DatasetDocumentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant pagination = MagicMock(items=[make_serializable_document()], total=1) @@ -1519,14 +1544,15 @@ class TestDocumentListAdvancedCases: return_value=None, ), ): - response = method(api, "ds-1") + response = method(api, tenant_id, user, "ds-1") assert response["total"] == 1 def test_document_metadata_with_schema_validation(self, app: Flask, patch_tenant): """Test document metadata update with schema validation""" api = DocumentMetadataApi() - method = unwrap(api.put) + method = inspect.unwrap(api.put) + user, tenant_id = patch_tenant doc = MagicMock() payload = { @@ -1548,7 +1574,7 @@ class TestDocumentListAdvancedCases: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 200 assert doc.doc_metadata == {"amount": 5000, "currency": "USD"} @@ -1557,7 +1583,8 @@ class TestDocumentListAdvancedCases: class TestDocumentIndexingEdgeCases: def test_document_indexing_with_extraction_setting(self, app: Flask, patch_tenant): api = DocumentIndexingEstimateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user, tenant_id = patch_tenant document = MagicMock( indexing_status=IndexingStatus.INDEXING, @@ -1586,6 +1613,6 @@ class TestDocumentIndexingEdgeCases: return_value=MagicMock(model_dump=lambda: {"tokens": 5}), ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, tenant_id, user, "ds-1", "doc-1") assert status == 200 diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py index 23aee22d63..3b9a1dcc5e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py @@ -8,11 +8,11 @@ upload-file documents, and rejects unsupported or missing file cases. from __future__ import annotations import importlib +import inspect import sys from collections import UserDict from io import BytesIO from types import SimpleNamespace -from typing import Any from zipfile import ZipFile import pytest @@ -79,7 +79,7 @@ def _mock_document( upload_file_id: str | None, ) -> SimpleNamespace: """Build a minimal document object used by the controller.""" - data_source_info_dict: dict[str, Any] | None = None + data_source_info_dict: dict[str, object] | None = None if upload_file_id is not None: data_source_info_dict = {"upload_file_id": upload_file_id} else: @@ -97,7 +97,6 @@ def _wire_common_success_mocks( *, module, monkeypatch: pytest.MonkeyPatch, - current_tenant_id: str, document_tenant_id: str, data_source_type: str, upload_file_id: str | None, @@ -107,9 +106,6 @@ def _wire_common_success_mocks( """Patch controller dependencies to create a deterministic test environment.""" import services.dataset_service as dataset_service_module - # Make `current_account_with_tenant()` return a known user + tenant id. - monkeypatch.setattr(module, "current_account_with_tenant", lambda: (_mock_user(), current_tenant_id)) - # Return a dataset object and allow permission checks to pass. monkeypatch.setattr(module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")) monkeypatch.setattr(module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None) @@ -124,9 +120,9 @@ def _wire_common_success_mocks( monkeypatch.setattr(module.DocumentService, "get_document", lambda *_args, **_kwargs: document) # Mock UploadFile lookup via FileService batch helper. - upload_files_by_id: dict[str, Any] = {} + upload_files_by_id: dict[str, object] = {} if upload_file_exists and upload_file_id is not None: - upload_files_by_id[str(upload_file_id)] = SimpleNamespace(id=str(upload_file_id)) + upload_files_by_id[upload_file_id] = SimpleNamespace(id=upload_file_id) monkeypatch.setattr(module.FileService, "get_upload_files_by_ids", lambda *_args, **_kwargs: upload_files_by_id) # Mock signing helper so the returned URL is deterministic. @@ -153,8 +149,6 @@ def test_batch_download_zip_returns_send_file( ) -> None: """Ensure batch ZIP download returns a zip attachment via `send_file`.""" - # Arrange common permission mocks. - monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123")) monkeypatch.setattr( datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1") ) @@ -204,7 +198,8 @@ def test_batch_download_zip_returns_send_file( json={"document_ids": ["11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222"]}, ): api = datasets_document_module.DocumentBatchDownloadZipApi() - result = api.post(dataset_id="ds-1") + method = inspect.unwrap(api.post) + result = method(api, "tenant-123", _mock_user(), dataset_id="ds-1") # Assert: we returned via send_file with correct mime type and attachment. assert result["_send_file_kwargs"]["mimetype"] == "application/zip" @@ -222,7 +217,6 @@ def test_batch_download_zip_response_is_openable_zip( """Ensure the real Flask `send_file` response body is a valid ZIP that can be opened.""" # Arrange: same controller mocks as the lightweight send_file test, but we keep the real `send_file`. - monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123")) monkeypatch.setattr( datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1") ) @@ -270,7 +264,8 @@ def test_batch_download_zip_response_is_openable_zip( json={"document_ids": ["33333333-3333-3333-3333-333333333333", "44444444-4444-4444-4444-444444444444"]}, ): api = datasets_document_module.DocumentBatchDownloadZipApi() - response = api.post(dataset_id="ds-1") + method = inspect.unwrap(api.post) + response = method(api, "tenant-123", _mock_user(), dataset_id="ds-1") # Assert: response body is a valid ZIP and contains the expected entries. response.direct_passthrough = False @@ -288,7 +283,6 @@ def test_batch_download_zip_rejects_non_upload_file_document( ) -> None: """Ensure batch ZIP download rejects non upload-file documents.""" - monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123")) monkeypatch.setattr( datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1") ) @@ -314,8 +308,9 @@ def test_batch_download_zip_rejects_non_upload_file_document( json={"document_ids": ["55555555-5555-5555-5555-555555555555"]}, ): api = datasets_document_module.DocumentBatchDownloadZipApi() + method = inspect.unwrap(api.post) with pytest.raises(NotFound): - api.post(dataset_id="ds-1") + method(api, "tenant-123", _mock_user(), dataset_id="ds-1") def test_document_download_returns_url_for_upload_file_document( @@ -326,7 +321,6 @@ def test_document_download_returns_url_for_upload_file_document( _wire_common_success_mocks( module=datasets_document_module, monkeypatch=monkeypatch, - current_tenant_id="tenant-123", document_tenant_id="tenant-123", data_source_type="upload_file", upload_file_id="file-123", @@ -337,7 +331,8 @@ def test_document_download_returns_url_for_upload_file_document( # Build a request context then call the resource method directly. with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): api = datasets_document_module.DocumentDownloadApi() - result = api.get(dataset_id="ds-1", document_id="doc-1") + method = inspect.unwrap(api.get) + result = method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1") assert result == {"url": "https://example.com/signed"} @@ -350,7 +345,6 @@ def test_document_download_rejects_non_upload_file_document( _wire_common_success_mocks( module=datasets_document_module, monkeypatch=monkeypatch, - current_tenant_id="tenant-123", document_tenant_id="tenant-123", data_source_type="website_crawl", upload_file_id="file-123", @@ -360,8 +354,9 @@ def test_document_download_rejects_non_upload_file_document( with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): api = datasets_document_module.DocumentDownloadApi() + method = inspect.unwrap(api.get) with pytest.raises(NotFound): - api.get(dataset_id="ds-1", document_id="doc-1") + method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1") def test_document_download_rejects_missing_upload_file_id( @@ -372,7 +367,6 @@ def test_document_download_rejects_missing_upload_file_id( _wire_common_success_mocks( module=datasets_document_module, monkeypatch=monkeypatch, - current_tenant_id="tenant-123", document_tenant_id="tenant-123", data_source_type="upload_file", upload_file_id=None, @@ -382,8 +376,9 @@ def test_document_download_rejects_missing_upload_file_id( with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): api = datasets_document_module.DocumentDownloadApi() + method = inspect.unwrap(api.get) with pytest.raises(NotFound): - api.get(dataset_id="ds-1", document_id="doc-1") + method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1") def test_document_download_rejects_when_upload_file_record_missing( @@ -394,7 +389,6 @@ def test_document_download_rejects_when_upload_file_record_missing( _wire_common_success_mocks( module=datasets_document_module, monkeypatch=monkeypatch, - current_tenant_id="tenant-123", document_tenant_id="tenant-123", data_source_type="upload_file", upload_file_id="file-123", @@ -404,8 +398,9 @@ def test_document_download_rejects_when_upload_file_record_missing( with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): api = datasets_document_module.DocumentDownloadApi() + method = inspect.unwrap(api.get) with pytest.raises(NotFound): - api.get(dataset_id="ds-1", document_id="doc-1") + method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1") def test_document_download_rejects_tenant_mismatch( @@ -416,7 +411,6 @@ def test_document_download_rejects_tenant_mismatch( _wire_common_success_mocks( module=datasets_document_module, monkeypatch=monkeypatch, - current_tenant_id="tenant-123", document_tenant_id="tenant-999", data_source_type="upload_file", upload_file_id="file-123", @@ -426,5 +420,6 @@ def test_document_download_rejects_tenant_mismatch( with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"): api = datasets_document_module.DocumentDownloadApi() + method = inspect.unwrap(api.get) with pytest.raises(Forbidden): - api.get(dataset_id="ds-1", document_id="doc-1") + method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index a07c110ed9..09d4da9474 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -1,3 +1,4 @@ +import inspect from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -34,12 +35,6 @@ from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDelete from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _segment(): segment = DocumentSegment( tenant_id="tenant-1", @@ -129,10 +124,11 @@ def test_segment_response_with_summary(): class TestDatasetDocumentSegmentListApi: def test_get_success(self, app: Flask): api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) dataset = MagicMock() document = MagicMock() + user = MagicMock() segment = _segment() @@ -143,10 +139,6 @@ class TestDatasetDocumentSegmentListApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -170,40 +162,34 @@ class TestDatasetDocumentSegmentListApi: patch("models.dataset.db.session.scalar", return_value=None), patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1") assert status == 200 def test_get_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock() with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_get_permission_denied(self, app: Flask): api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) dataset = MagicMock() + user = MagicMock() with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -214,13 +200,13 @@ class TestDatasetDocumentSegmentListApi: ), ): with pytest.raises(Forbidden): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") class TestDatasetDocumentSegmentApi: def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) user = MagicMock() user.is_dataset_editor = True @@ -233,10 +219,6 @@ class TestDatasetDocumentSegmentApi: with ( app.test_request_context("/?segment_id=s1&segment_id=s2"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -258,14 +240,14 @@ class TestDatasetDocumentSegmentApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1", "enable") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "enable") assert status == 200 assert response["result"] == "success" def test_patch_document_indexing_in_progress(self, app: Flask): api = DatasetDocumentSegmentApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) user = MagicMock() user.is_dataset_editor = True @@ -278,10 +260,6 @@ class TestDatasetDocumentSegmentApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -304,11 +282,11 @@ class TestDatasetDocumentSegmentApi: ), ): with pytest.raises(InvalidActionError): - method(api, "ds-1", "doc-1", "disable") + method(api, "tenant-1", user, "ds-1", "doc-1", "disable") def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) user = MagicMock(is_dataset_editor=True) @@ -322,10 +300,6 @@ class TestDatasetDocumentSegmentApi: with ( app.test_request_context("/?segment_id=s1"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -348,11 +322,11 @@ class TestDatasetDocumentSegmentApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, "ds-1", "doc-1", "enable") + method(api, "tenant-1", user, "ds-1", "doc-1", "enable") def test_patch_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) user = MagicMock(is_dataset_editor=True) @@ -366,10 +340,6 @@ class TestDatasetDocumentSegmentApi: with ( app.test_request_context("/?segment_id=s1"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -392,13 +362,13 @@ class TestDatasetDocumentSegmentApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, "ds-1", "doc-1", "enable") + method(api, "tenant-1", user, "ds-1", "doc-1", "enable") class TestDatasetDocumentSegmentAddApi: def test_post_success(self, app: Flask): api = DatasetDocumentSegmentAddApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"content": "hello"} @@ -416,10 +386,6 @@ class TestDatasetDocumentSegmentAddApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -447,14 +413,14 @@ class TestDatasetDocumentSegmentAddApi: patch("models.dataset.db.session.scalar", return_value=None), patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1") assert status == 200 assert response["data"]["id"] == "seg-1" def test_post_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentAddApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"content": "x"} @@ -471,10 +437,6 @@ class TestDatasetDocumentSegmentAddApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -489,11 +451,11 @@ class TestDatasetDocumentSegmentAddApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_post_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentAddApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"content": "x"} @@ -510,10 +472,6 @@ class TestDatasetDocumentSegmentAddApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -528,13 +486,13 @@ class TestDatasetDocumentSegmentAddApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") class TestDatasetDocumentSegmentUpdateApi: def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) payload = {"content": "updated"} @@ -552,10 +510,6 @@ class TestDatasetDocumentSegmentUpdateApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -586,14 +540,14 @@ class TestDatasetDocumentSegmentUpdateApi: ), patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))), ): - response, status = method(api, "ds-1", "doc-1", "seg-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1") assert status == 200 assert "data" in response def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) payload = {"content": "x"} @@ -610,10 +564,6 @@ class TestDatasetDocumentSegmentUpdateApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -632,26 +582,23 @@ class TestDatasetDocumentSegmentUpdateApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, "ds-1", "doc-1", "seg-1") + method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1") class TestDatasetDocumentSegmentBatchImportApi: def test_post_success(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"upload_file_id": "file-1"} upload_file = MagicMock(spec=UploadFile) upload_file.name = "test.csv" + user = MagicMock(id="u1") with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -673,45 +620,39 @@ class TestDatasetDocumentSegmentBatchImportApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1") assert status == 200 assert response["job_status"] == "waiting" def test_post_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"upload_file_id": "file-1"} + user = MagicMock(id="u1") with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_post_document_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"upload_file_id": "file-1"} + user = MagicMock(id="u1") with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -722,21 +663,18 @@ class TestDatasetDocumentSegmentBatchImportApi: ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_post_upload_file_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"upload_file_id": "file-1"} + user = MagicMock(id="u1") with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -751,24 +689,21 @@ class TestDatasetDocumentSegmentBatchImportApi: ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_post_invalid_file_type(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"upload_file_id": "file-1"} upload_file = MagicMock() upload_file.name = "test.txt" + user = MagicMock(id="u1") with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -783,24 +718,21 @@ class TestDatasetDocumentSegmentBatchImportApi: ), ): with pytest.raises(ValueError): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_post_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"upload_file_id": "file-1"} upload_file = MagicMock() upload_file.name = "test.csv" + user = MagicMock(id="u1") with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(id="u1"), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -818,14 +750,14 @@ class TestDatasetDocumentSegmentBatchImportApi: side_effect=Exception("redis down"), ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1") assert status == 500 assert "error" in response def test_get_job_not_found_in_redis(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -840,23 +772,19 @@ class TestDatasetDocumentSegmentBatchImportApi: class TestChildChunkAddApi: def test_patch_documents_batch_update_payload(self): - api_doc = unwrap(ChildChunkAddApi.patch).__apidoc__ + api_doc = getattr(ChildChunkAddApi.patch, "__apidoc__") # noqa: B009 expected_model = ChildChunkBatchUpdatePayload.__name__ assert [model.name for model in api_doc["expect"]] == [expected_model] def test_get_uses_default_pagination_for_malformed_ints(self, app: Flask): api = ChildChunkAddApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) pagination = MagicMock(items=[], total=0, pages=0) with ( app.test_request_context("/?page=bad&limit="), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -878,7 +806,7 @@ class TestChildChunkAddApi: return_value=pagination, ) as get_child_chunks, ): - response, status = method(api, "ds-1", "doc-1", "seg-1") + response, status = method(api, "tenant-1", "ds-1", "doc-1", "seg-1") assert status == 200 assert response["page"] == 1 @@ -887,7 +815,7 @@ class TestChildChunkAddApi: def test_post_success(self, app: Flask): api = ChildChunkAddApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"content": "child"} @@ -904,10 +832,6 @@ class TestChildChunkAddApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -929,14 +853,14 @@ class TestChildChunkAddApi: return_value=child_chunk, ), ): - response, status = method(api, "ds-1", "doc-1", "seg-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1") assert status == 200 assert response["data"]["id"] == "cc-1" def test_post_child_chunk_indexing_error(self, app: Flask): api = ChildChunkAddApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"content": "child"} @@ -949,10 +873,6 @@ class TestChildChunkAddApi: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -975,13 +895,13 @@ class TestChildChunkAddApi: ), ): with pytest.raises(ChildChunkIndexingError): - method(api, "ds-1", "doc-1", "seg-1") + method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1") class TestChildChunkUpdateApi: def test_delete_success(self, app: Flask): api = ChildChunkUpdateApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) user = MagicMock() user.is_dataset_editor = True @@ -993,10 +913,6 @@ class TestChildChunkUpdateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1018,14 +934,14 @@ class TestChildChunkUpdateApi: return_value=None, ), ): - response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1", "cc-1") assert status == 204 assert response == "" def test_delete_child_chunk_index_error(self, app: Flask): api = ChildChunkUpdateApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) user = MagicMock(is_dataset_editor=True) @@ -1036,10 +952,6 @@ class TestChildChunkUpdateApi: with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1062,16 +974,17 @@ class TestChildChunkUpdateApi: ), ): with pytest.raises(ChildChunkDeleteIndexError): - method(api, "ds-1", "doc-1", "seg-1", "cc-1") + method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1", "cc-1") class TestSegmentListAdvancedCases: def test_segment_list_with_keyword_filter(self, app: Flask): api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) dataset = MagicMock() document = MagicMock() + user = MagicMock() segment = _segment() @@ -1079,10 +992,6 @@ class TestSegmentListAdvancedCases: with ( app.test_request_context("/?keyword=test"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1106,7 +1015,7 @@ class TestSegmentListAdvancedCases: patch("models.dataset.db.session.scalar", return_value=None), patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))), ): - result = method(api, "ds-1", "doc-1") + result = method(api, "tenant-1", user, "ds-1", "doc-1") if isinstance(result, tuple): response, status = result @@ -1118,18 +1027,15 @@ class TestSegmentListAdvancedCases: def test_segment_list_postgres_keyword_filter_handles_scalar_keywords(self, app: Flask): api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) dataset = MagicMock() document = MagicMock() + user = MagicMock() pagination = MagicMock(items=[], total=0, pages=0) with ( app.test_request_context("/?keyword=test"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "11111111-1111-1111-1111-111111111111"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1151,7 +1057,13 @@ class TestSegmentListAdvancedCases: return_value=pagination, ) as paginate_mock, ): - method(api, "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333") + method( + api, + "11111111-1111-1111-1111-111111111111", + user, + "22222222-2222-2222-2222-222222222222", + "33333333-3333-3333-3333-333333333333", + ) query = paginate_mock.call_args.kwargs["select"] sql = str(query.compile(compile_kwargs={"literal_binds": True})) @@ -1161,14 +1073,11 @@ class TestSegmentListAdvancedCases: def test_segment_list_permission_denied(self, app: Flask): """Test segment list with permission denied""" api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock() with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=MagicMock(), @@ -1179,33 +1088,30 @@ class TestSegmentListAdvancedCases: ), ): with pytest.raises(Forbidden): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_segment_list_dataset_not_found(self, app: Flask): """Test segment list with dataset not found""" api = DatasetDocumentSegmentListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = MagicMock() with ( app.test_request_context("/"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") class TestSegmentOperationCases: def test_segment_add_with_provider_token_error(self, app: Flask): """Test segment add with provider token not initialized""" api = DatasetDocumentSegmentAddApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) user = MagicMock(is_dataset_editor=True) dataset = MagicMock() @@ -1216,10 +1122,6 @@ class TestSegmentOperationCases: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1238,12 +1140,12 @@ class TestSegmentOperationCases: ), ): with pytest.raises(ProviderTokenNotInitError): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_batch_import_with_document_not_found(self, app: Flask): """Test batch import with document not found""" api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) user = MagicMock(is_dataset_editor=True) dataset = MagicMock() @@ -1253,10 +1155,6 @@ class TestSegmentOperationCases: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1267,12 +1165,12 @@ class TestSegmentOperationCases: ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_batch_import_with_invalid_file(self, app: Flask): """Test batch import with invalid file type""" api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) user = MagicMock(is_dataset_editor=True) dataset = MagicMock() @@ -1284,10 +1182,6 @@ class TestSegmentOperationCases: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1302,11 +1196,11 @@ class TestSegmentOperationCases: ), ): with pytest.raises(NotFound): - method(api, "ds-1", "doc-1") + method(api, "tenant-1", user, "ds-1", "doc-1") def test_batch_import_with_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) user = MagicMock(is_dataset_editor=True) dataset = MagicMock() @@ -1319,10 +1213,6 @@ class TestSegmentOperationCases: with ( app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", return_value=dataset, @@ -1344,23 +1234,17 @@ class TestSegmentOperationCases: side_effect=Exception("Task failed"), ), ): - response, status = method(api, "ds-1", "doc-1") + response, status = method(api, "tenant-1", user, "ds-1", "doc-1") assert status == 500 assert "error" in response def test_batch_import_get_job_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() - method = unwrap(api.get) - - user = MagicMock(is_dataset_editor=True) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?job_id=invalid-job"), - patch( - "controllers.console.datasets.datasets_segments.current_account_with_tenant", - return_value=(user, "tenant-1"), - ), patch( "controllers.console.datasets.datasets_segments.redis_client.get", return_value=None, diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 3e76e6c21a..7cb41dc99c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -1,4 +1,4 @@ -from importlib import import_module +import inspect from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -16,50 +16,32 @@ from controllers.console.datasets.external import ( ExternalDatasetCreateApi, ExternalKnowledgeHitTestingApi, ) +from models.account import Account, TenantAccountRole from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService from services.knowledge_service import ExternalDatasetTestService -external_controller = import_module("controllers.console.datasets.external") - - -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - @pytest.fixture -def app(): +def app() -> Flask: app = Flask("test_external_dataset") app.config["TESTING"] = True return app @pytest.fixture -def current_user(): - user = MagicMock() +def current_user() -> Account: + user = Account(name="Test User", email="user-1@example.com") user.id = "user-1" - user.is_dataset_editor = True - user.has_edit_permission = True - user.is_dataset_operator = True + user.role = TenantAccountRole.EDITOR return user -@pytest.fixture(autouse=True) -def mock_auth(monkeypatch: pytest.MonkeyPatch, current_user): - monkeypatch.setattr( - external_controller, - "current_account_with_tenant", - lambda: (current_user, "tenant-1"), - ) - - class TestExternalApiTemplateListApi: def test_get_success(self, app: Flask): api = ExternalApiTemplateListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) api_item = MagicMock() api_item.to_dict.return_value = {"id": "1"} @@ -79,10 +61,10 @@ class TestExternalApiTemplateListApi: assert resp["data"][0]["id"] == "1" get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) - def test_post_forbidden(self, app: Flask, current_user): - current_user.is_dataset_editor = False + def test_post_forbidden(self, app: Flask, current_user: Account): + current_user.role = TenantAccountRole.NORMAL api = ExternalApiTemplateListApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"name": "x", "settings": {"k": "v"}} @@ -92,11 +74,11 @@ class TestExternalApiTemplateListApi: patch.object(ExternalDatasetService, "validate_api_list"), ): with pytest.raises(Forbidden): - method(api) + method(api, "tenant-1", current_user) - def test_post_duplicate_name(self, app: Flask): + def test_post_duplicate_name(self, app: Flask, current_user: Account): api = ExternalApiTemplateListApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"name": "x", "settings": {"k": "v"}} @@ -111,13 +93,13 @@ class TestExternalApiTemplateListApi: ), ): with pytest.raises(DatasetNameDuplicateError): - method(api) + method(api, "tenant-1", current_user) class TestExternalApiTemplateApi: def test_get_not_found(self, app: Flask): api = ExternalApiTemplateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -128,24 +110,23 @@ class TestExternalApiTemplateApi: ), ): with pytest.raises(NotFound): - method(api, "api-id") + method(api, "tenant-1", "api-id") - def test_delete_forbidden(self, app: Flask, current_user): - current_user.has_edit_permission = False - current_user.is_dataset_operator = False + def test_delete_forbidden(self, app: Flask, current_user: Account): + current_user.role = TenantAccountRole.NORMAL api = ExternalApiTemplateApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with app.test_request_context("/"): with pytest.raises(Forbidden): - method(api, "api-id") + method(api, "tenant-1", current_user, "api-id") class TestExternalApiUseCheckApi: def test_get_scopes_usage_check_to_current_tenant(self, app: Flask): api = ExternalApiUseCheckApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/"), @@ -155,7 +136,7 @@ class TestExternalApiUseCheckApi: return_value=(True, 2), ) as mock_use_check, ): - response, status = method(api, "api-id") + response, status = method(api, "tenant-1", "api-id") assert status == 200 assert response == {"is_using": True, "count": 2} @@ -163,9 +144,9 @@ class TestExternalApiUseCheckApi: class TestExternalDatasetCreateApi: - def test_create_success(self, app: Flask): + def test_create_success(self, app: Flask, current_user: Account): api = ExternalDatasetCreateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "external_knowledge_api_id": "api", @@ -203,14 +184,14 @@ class TestExternalDatasetCreateApi: return_value=dataset, ), ): - _, status = method(api) + _, status = method(api, "tenant-1", current_user) assert status == 201 - def test_create_forbidden(self, app: Flask, current_user): - current_user.is_dataset_editor = False + def test_create_forbidden(self, app: Flask, current_user: Account): + current_user.role = TenantAccountRole.NORMAL api = ExternalDatasetCreateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "external_knowledge_api_id": "api", @@ -223,13 +204,13 @@ class TestExternalDatasetCreateApi: patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), ): with pytest.raises(Forbidden): - method(api) + method(api, "tenant-1", current_user) class TestExternalKnowledgeHitTestingApi: - def test_hit_testing_dataset_not_found(self, app: Flask): + def test_hit_testing_dataset_not_found(self, app: Flask, current_user: Account): api = ExternalKnowledgeHitTestingApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) with ( app.test_request_context("/"), @@ -240,11 +221,11 @@ class TestExternalKnowledgeHitTestingApi: ), ): with pytest.raises(NotFound): - method(api, "dataset-id") + method(api, current_user, "dataset-id") - def test_hit_testing_success(self, app: Flask): + def test_hit_testing_success(self, app: Flask, current_user: Account): api = ExternalKnowledgeHitTestingApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"query": "hello"} @@ -261,7 +242,7 @@ class TestExternalKnowledgeHitTestingApi: return_value={"ok": True}, ), ): - resp = method(api, "dataset-id") + resp = method(api, current_user, "dataset-id") assert resp["ok"] is True @@ -269,7 +250,7 @@ class TestExternalKnowledgeHitTestingApi: class TestBedrockRetrievalApi: def test_bedrock_retrieval(self, app: Flask): api = BedrockRetrievalApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "retrieval_setting": {}, @@ -293,9 +274,9 @@ class TestBedrockRetrievalApi: class TestExternalApiTemplateListApiAdvanced: - def test_post_duplicate_name_error(self, app: Flask, mock_auth, current_user): + def test_post_duplicate_name_error(self, app: Flask, current_user: Account): api = ExternalApiTemplateListApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"name": "duplicate_api", "settings": {"key": "value"}} @@ -309,11 +290,11 @@ class TestExternalApiTemplateListApiAdvanced: ), ): with pytest.raises(DatasetNameDuplicateError): - method(api) + method(api, "tenant-1", current_user) - def test_get_with_pagination(self, app: Flask, mock_auth, current_user): + def test_get_with_pagination(self, app: Flask): api = ExternalApiTemplateListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) templates = [MagicMock(id=f"api-{i}") for i in range(3)] @@ -333,12 +314,12 @@ class TestExternalApiTemplateListApiAdvanced: class TestExternalDatasetCreateApiAdvanced: - def test_create_forbidden(self, app: Flask, mock_auth, current_user): + def test_create_forbidden(self, app: Flask, current_user: Account): """Test creating external dataset without permission""" api = ExternalDatasetCreateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - current_user.is_dataset_editor = False + current_user.role = TenantAccountRole.NORMAL payload = { "external_knowledge_api_id": "api-1", @@ -349,14 +330,14 @@ class TestExternalDatasetCreateApiAdvanced: with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): with pytest.raises(Forbidden): - method(api) + method(api, "tenant-1", current_user) class TestExternalKnowledgeHitTestingApiAdvanced: - def test_hit_testing_dataset_not_found(self, app: Flask, mock_auth, current_user): + def test_hit_testing_dataset_not_found(self, app: Flask, current_user: Account): """Test hit testing on non-existent dataset""" api = ExternalKnowledgeHitTestingApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "query": "test query", @@ -372,11 +353,11 @@ class TestExternalKnowledgeHitTestingApiAdvanced: ), ): with pytest.raises(NotFound): - method(api, "ds-1") + method(api, current_user, "ds-1") - def test_hit_testing_with_custom_retrieval_model(self, app: Flask, mock_auth, current_user): + def test_hit_testing_with_custom_retrieval_model(self, app: Flask, current_user: Account): api = ExternalKnowledgeHitTestingApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) dataset = MagicMock() payload = { @@ -398,15 +379,15 @@ class TestExternalKnowledgeHitTestingApiAdvanced: return_value={"results": []}, ), ): - resp = method(api, "ds-1") + resp = method(api, current_user, "ds-1") assert resp["results"] == [] class TestBedrockRetrievalApiAdvanced: - def test_bedrock_retrieval_with_invalid_setting(self, app: Flask, mock_auth, current_user): + def test_bedrock_retrieval_with_invalid_setting(self, app: Flask): api = BedrockRetrievalApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "retrieval_setting": {}, diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 95d7493b71..c39d0930be 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -1,8 +1,9 @@ +import inspect from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest -from flask import Flask, g +from flask import Flask from controllers.console.workspace.account import ( AccountDeleteUpdateFeedbackApi, @@ -11,7 +12,7 @@ from controllers.console.workspace.account import ( ChangeEmailSendEmailApi, CheckEmailUnique, ) -from models import Account, AccountStatus +from models import Account, AccountStatus, Tenant from services.account_service import AccountService from services.entities.auth_entities import ( ChangeEmailNewEmailToken, @@ -26,12 +27,16 @@ def app(): app = Flask(__name__) app.config["TESTING"] = True app.config["RESTX_MASK_HEADER"] = "X-Fields" - app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None) + setattr(app, "login_manager", SimpleNamespace(load_user_from_request_context=lambda: None)) # noqa: B010 return app -def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account: - tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id") +def _build_account(email: str, account_id: str = "acc", tenant: Tenant | None = None) -> Account: + if tenant is None: + tenant_obj = Tenant(name="Tenant") + tenant_obj.id = "tenant-id" + else: + tenant_obj = tenant account = Account(name=account_id, email=email) account.email = email account.id = account_id @@ -40,11 +45,6 @@ def _build_account(email: str, account_id: str = "acc", tenant: object | None = return account -def _set_logged_in_user(account: Account): - g._login_user = account - g._current_tenant = account.current_tenant - - def _build_change_email_token( phase: str, *, @@ -71,63 +71,44 @@ def _build_change_email_token( class TestChangeEmailSend: - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.send_change_email_email") @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_old_email_phase_when_request_email_does_not_match_current_user( self, - mock_features, - mock_csrf, mock_extract_ip, mock_is_ip_limit, mock_send_email, - mock_current_account, - mock_db, app: Flask, ): from controllers.console.auth.error import InvalidEmailError - mock_features.return_value = SimpleNamespace(enable_change_email=True) - mock_current_account.return_value = (_build_account("current@example.com", "acc1"), None) + current_user = _build_account("current@example.com", "acc1") with app.test_request_context( "/account/change-email", method="POST", json={"email": "other@example.com", "language": "en-US", "phase": "old_email"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + method = inspect.unwrap(ChangeEmailSendEmailApi().post) with pytest.raises(InvalidEmailError): - ChangeEmailSendEmailApi().post() + method(ChangeEmailSendEmailApi(), current_user) mock_send_email.assert_not_called() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.send_change_email_email") @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_normalize_new_email_phase( self, - mock_features, - mock_csrf, mock_extract_ip, mock_is_ip_limit, mock_send_email, mock_get_change_data, - mock_current_account, - mock_db, app: Flask, ): - mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") - mock_current_account.return_value = (mock_account, None) mock_get_change_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, account_id="acc1", @@ -141,8 +122,9 @@ class TestChangeEmailSend: method="POST", json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) - response = ChangeEmailSendEmailApi().post() + api = ChangeEmailSendEmailApi() + method = inspect.unwrap(api.post) + response = method(api, mock_account) assert response == {"result": "success", "data": "token-abc"} mock_send_email.assert_called_once_with( @@ -154,34 +136,23 @@ class TestChangeEmailSend: ) mock_extract_ip.assert_called_once() mock_is_ip_limit.assert_called_once_with("127.0.0.1") - mock_csrf.assert_called_once() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.send_change_email_email") @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified( self, - mock_features, - mock_csrf, mock_extract_ip, mock_is_ip_limit, mock_send_email, mock_get_change_data, - mock_current_account, - mock_db, app: Flask, ): """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") - mock_current_account.return_value = (mock_account, None) mock_get_change_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_OLD, account_id="acc1", @@ -194,37 +165,28 @@ class TestChangeEmailSend: method="POST", json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailSendEmailApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailSendEmailApi().post() + method(api, mock_account) mock_send_email.assert_not_called() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.send_change_email_email") @patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False) @patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user( self, - mock_features, - mock_csrf, mock_extract_ip, mock_is_ip_limit, mock_send_email, mock_get_change_data, - mock_current_account, - mock_db, app: Flask, ): from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") - mock_current_account.return_value = (mock_account, None) mock_get_change_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, account_id="other-account", @@ -237,41 +199,32 @@ class TestChangeEmailSend: method="POST", json={"email": "new@example.com", "language": "en-US", "phase": "new_email", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailSendEmailApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailSendEmailApi().post() + method(api, mock_account) mock_send_email.assert_not_called() class TestChangeEmailValidity: - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_validate_with_normalized_email( self, - mock_features, - mock_csrf, mock_is_rate_limit, mock_get_data, mock_add_rate, mock_revoke_token, mock_generate_token, mock_reset_rate, - mock_current_account, - mock_db, app: Flask, ): - mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") - mock_current_account.return_value = (mock_account, None) mock_is_rate_limit.return_value = False mock_get_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_OLD, @@ -286,8 +239,9 @@ class TestChangeEmailValidity: method="POST", json={"email": "User@Example.com", "code": "1234", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) - response = ChangeEmailCheckApi().post() + api = ChangeEmailCheckApi() + method = inspect.unwrap(api.post) + response = method(api, mock_account) assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"} mock_is_rate_limit.assert_called_once_with("user@example.com") @@ -303,34 +257,24 @@ class TestChangeEmailValidity: mock_account, ) mock_reset_rate.assert_called_once_with("user@example.com") - mock_csrf.assert_called_once() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_upgrade_new_phase_token_to_new_verified( self, - mock_features, - mock_csrf, mock_is_rate_limit, mock_get_data, mock_add_rate, mock_revoke_token, mock_generate_token, mock_reset_rate, - mock_current_account, - mock_db, app: Flask, ): - mock_features.return_value = SimpleNamespace(enable_change_email=True) - mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + current_user = _build_account("old@example.com", "acc") mock_is_rate_limit.return_value = False mock_get_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_NEW, @@ -345,8 +289,9 @@ class TestChangeEmailValidity: method="POST", json={"email": "new@example.com", "code": "1234", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) - response = ChangeEmailCheckApi().post() + api = ChangeEmailCheckApi() + method = inspect.unwrap(api.post) + response = method(api, current_user) assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"} mock_generate_token.assert_called_once_with( @@ -356,37 +301,28 @@ class TestChangeEmailValidity: email="new@example.com", old_email="old@example.com", ), - mock_current_account.return_value[0], + current_user, ) - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_validity_when_token_is_already_verified( self, - mock_features, - mock_csrf, mock_is_rate_limit, mock_get_data, mock_add_rate, mock_revoke_token, mock_generate_token, mock_reset_rate, - mock_current_account, - mock_db, app: Flask, ): from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) - mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + current_user = _build_account("old@example.com", "acc") mock_is_rate_limit.return_value = False mock_get_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED, @@ -400,41 +336,33 @@ class TestChangeEmailValidity: method="POST", json={"email": "old@example.com", "code": "1234", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailCheckApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailCheckApi().post() + method(api, current_user) mock_revoke_token.assert_not_called() mock_generate_token.assert_not_called() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.generate_change_email_token") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_validity_when_token_account_id_does_not_match_current_user( self, - mock_features, - mock_csrf, mock_is_rate_limit, mock_get_data, mock_add_rate, mock_revoke_token, mock_generate_token, mock_reset_rate, - mock_current_account, - mock_db, app: Flask, ): from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) - mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) + current_user = _build_account("old@example.com", "acc") mock_is_rate_limit.return_value = False mock_get_data.return_value = _build_change_email_token( AccountService.CHANGE_EMAIL_PHASE_NEW, @@ -448,42 +376,33 @@ class TestChangeEmailValidity: method="POST", json={"email": "new@example.com", "code": "1234", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailCheckApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailCheckApi().post() + method(api, current_user) mock_revoke_token.assert_not_called() mock_generate_token.assert_not_called() class TestChangeEmailReset: - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") @patch("controllers.console.workspace.account.AccountService.update_account_email") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_normalize_new_email_before_update( self, - mock_features, - mock_csrf, mock_is_freeze, mock_check_unique, mock_get_data, mock_revoke_token, mock_update_account, mock_send_notify, - mock_current_account, - mock_db, app: Flask, ): - mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") - mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True mock_get_data.return_value = _build_change_email_token( @@ -500,46 +419,36 @@ class TestChangeEmailReset: method="POST", json={"new_email": "New@Example.com", "token": "token-123"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) - ChangeEmailResetApi().post() + api = ChangeEmailResetApi() + method = inspect.unwrap(api.post) + method(api, current_user) mock_is_freeze.assert_called_once_with("new@example.com") mock_check_unique.assert_called_once_with("new@example.com") mock_revoke_token.assert_called_once_with("token-123") mock_update_account.assert_called_once_with(current_user, email="new@example.com") mock_send_notify.assert_called_once_with(email="new@example.com") - mock_csrf.assert_called_once() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") @patch("controllers.console.workspace.account.AccountService.update_account_email") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_reset_when_token_phase_is_not_new_verified( self, - mock_features, - mock_csrf, mock_is_freeze, mock_check_unique, mock_get_data, mock_revoke_token, mock_update_account, mock_send_notify, - mock_current_account, - mock_db, app: Flask, ): """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") - mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True mock_get_data.return_value = _build_change_email_token( @@ -554,44 +463,35 @@ class TestChangeEmailReset: method="POST", json={"new_email": "attacker@example.com", "token": "token-from-step1"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailResetApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailResetApi().post() + method(api, current_user) mock_revoke_token.assert_not_called() mock_update_account.assert_not_called() mock_send_notify.assert_not_called() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") @patch("controllers.console.workspace.account.AccountService.update_account_email") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_reset_when_token_email_differs_from_payload_new_email( self, - mock_features, - mock_csrf, mock_is_freeze, mock_check_unique, mock_get_data, mock_revoke_token, mock_update_account, mock_send_notify, - mock_current_account, - mock_db, app: Flask, ): """A verified token for address A must not be replayed to change to address B.""" from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") - mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True mock_get_data.return_value = _build_change_email_token( @@ -606,43 +506,34 @@ class TestChangeEmailReset: method="POST", json={"new_email": "attacker@example.com", "token": "token-verified"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailResetApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailResetApi().post() + method(api, current_user) mock_revoke_token.assert_not_called() mock_update_account.assert_not_called() mock_send_notify.assert_not_called() - @patch("controllers.console.wraps.db") - @patch("controllers.console.workspace.account.current_account_with_tenant") @patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email") @patch("controllers.console.workspace.account.AccountService.update_account_email") @patch("controllers.console.workspace.account.AccountService.revoke_change_email_token") @patch("controllers.console.workspace.account.AccountService.get_change_email_data") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - @patch("libs.login.check_csrf_token", return_value=None) - @patch("controllers.console.wraps.FeatureService.get_system_features") def test_should_reject_reset_when_token_account_id_does_not_match_current_user( self, - mock_features, - mock_csrf, mock_is_freeze, mock_check_unique, mock_get_data, mock_revoke_token, mock_update_account, mock_send_notify, - mock_current_account, - mock_db, app: Flask, ): from controllers.console.auth.error import InvalidTokenError - mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") - mock_current_account.return_value = (current_user, None) mock_is_freeze.return_value = False mock_check_unique.return_value = True mock_get_data.return_value = _build_change_email_token( @@ -657,9 +548,10 @@ class TestChangeEmailReset: method="POST", json={"new_email": "new@example.com", "token": "token-verified"}, ): - _set_logged_in_user(_build_account("tester@example.com", "tester")) + api = ChangeEmailResetApi() + method = inspect.unwrap(api.post) with pytest.raises(InvalidTokenError): - ChangeEmailResetApi().post() + method(api, current_user) mock_revoke_token.assert_not_called() mock_update_account.assert_not_called() @@ -755,25 +647,25 @@ class TestAccountServiceGetChangeEmailData: class TestAccountDeletionFeedback: - @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") - def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask): + def test_should_normalize_feedback_email(self, mock_update, app: Flask): with app.test_request_context( "/account/delete/feedback", method="POST", json={"email": "User@Example.com", "feedback": "test"}, ): - response = AccountDeleteUpdateFeedbackApi().post() + api = AccountDeleteUpdateFeedbackApi() + method = inspect.unwrap(api.post) + response = method(api) assert response == {"result": "success"} mock_update.assert_called_once_with("User@Example.com", "test") class TestCheckEmailUnique: - @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask): + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask): mock_is_freeze.return_value = False mock_check_unique.return_value = True @@ -782,7 +674,9 @@ class TestCheckEmailUnique: method="POST", json={"email": "Case@Test.com"}, ): - response = CheckEmailUnique().post() + api = CheckEmailUnique() + method = inspect.unwrap(api.post) + response = method(api) assert response == {"result": "success"} mock_is_freeze.assert_called_once_with("case@test.com") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index aa58db81da..6f22f5c440 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -1,3 +1,4 @@ +import inspect from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -31,22 +32,29 @@ from controllers.console.workspace.error import ( CurrentPasswordIncorrectError, InvalidAccountDeletionCodeError, ) +from models import Account +from models.account import AccountStatus from models.enums import CreatorUserRole from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +def make_account(account_id: str = "u1", *, status: AccountStatus = AccountStatus.ACTIVE) -> Account: + account = Account(name="John", email=f"{account_id}@test.com", status=status) + account.id = account_id + account.avatar = "avatar.png" + account.interface_language = "en-US" + account.interface_theme = "light" + account.timezone = "UTC" + account.last_login_ip = "127.0.0.1" + return account class TestAccountInitApi: def test_init_success(self, app: Flask): api = AccountInitApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - account = MagicMock(status="inactive") + account = make_account(status=AccountStatus.UNINITIALIZED) payload = { "interface_language": "en-US", "timezone": "UTC", @@ -55,50 +63,35 @@ class TestAccountInitApi: with ( app.test_request_context("/account/init", json=payload), - patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), patch("controllers.console.workspace.account.db.session.commit", return_value=None), patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock, ): scalar_mock.return_value = MagicMock(status="unused") - resp = method(api) + resp = method(api, account) assert resp["result"] == "success" def test_init_already_initialized(self, app: Flask): api = AccountInitApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - account = MagicMock(status="active") + account = make_account() - with ( - app.test_request_context("/account/init"), - patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), - ): + with app.test_request_context("/account/init"): with pytest.raises(AccountAlreadyInitedError): - method(api) + method(api, account) class TestAccountProfileApi: def test_get_profile_success(self, app: Flask): api = AccountProfileApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - user = MagicMock() - user.id = "u1" - user.name = "John" - user.email = "john@test.com" - user.avatar = "avatar.png" - user.interface_language = "en-US" - user.interface_theme = "light" - user.timezone = "UTC" - user.last_login_ip = "127.0.0.1" + user = make_account() - with ( - app.test_request_context("/account/profile"), - patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), - ): - result = method(api) + with app.test_request_context("/account/profile"): + result = method(api, user) assert result["id"] == "u1" @@ -116,24 +109,15 @@ class TestAccountUpdateApis: ) def test_update_success(self, app: Flask, api_cls, payload): api = api_cls() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - user = MagicMock() - user.id = "u1" - user.name = "John" - user.email = "john@test.com" - user.avatar = "avatar.png" - user.interface_language = "en-US" - user.interface_theme = "light" - user.timezone = "UTC" - user.last_login_ip = "127.0.0.1" + user = make_account() with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.account.AccountService.update_account", return_value=user), ): - result = method(api) + result = method(api, user) assert result["id"] == "u1" @@ -143,10 +127,9 @@ class TestAccountAvatarApiGet: def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask): api = AccountAvatarApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - user = MagicMock() - user.id = "acc-owner" + user = make_account("acc-owner") tenant_id = "tenant-1" file_id = "550e8400-e29b-41d4-a716-446655440000" @@ -158,27 +141,22 @@ class TestAccountAvatarApiGet: with ( app.test_request_context(f"/account/avatar?avatar={file_id}"), - patch( - "controllers.console.workspace.account.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file), patch( "controllers.console.workspace.account.file_helpers.get_signed_file_url", return_value="https://signed/example", ) as sign_mock, ): - result = method(api) + result = method(api, tenant_id, user) assert result == {"avatar_url": "https://signed/example"} sign_mock.assert_called_once_with(upload_file_id=file_id) def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask): api = AccountAvatarApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - user = MagicMock() - user.id = "acc-a" + user = make_account("acc-a") tenant_id = "tenant-1" file_id = "550e8400-e29b-41d4-a716-446655440001" @@ -190,10 +168,6 @@ class TestAccountAvatarApiGet: with ( app.test_request_context(f"/account/avatar?avatar={file_id}"), - patch( - "controllers.console.workspace.account.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file), patch( "controllers.console.workspace.account.file_helpers.get_signed_file_url", @@ -201,16 +175,15 @@ class TestAccountAvatarApiGet: ) as sign_mock, ): with pytest.raises(NotFound): - method(api) + method(api, tenant_id, user) sign_mock.assert_not_called() def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask): api = AccountAvatarApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - user = MagicMock() - user.id = "acc-owner" + user = make_account("acc-owner") tenant_id = "tenant-1" file_id = "550e8400-e29b-41d4-a716-446655440002" @@ -222,10 +195,6 @@ class TestAccountAvatarApiGet: with ( app.test_request_context(f"/account/avatar?avatar={file_id}"), - patch( - "controllers.console.workspace.account.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file), patch( "controllers.console.workspace.account.file_helpers.get_signed_file_url", @@ -233,31 +202,26 @@ class TestAccountAvatarApiGet: ) as sign_mock, ): with pytest.raises(NotFound): - method(api) + method(api, tenant_id, user) sign_mock.assert_not_called() def test_get_avatar_https_pass_through_without_signing(self, app: Flask): api = AccountAvatarApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - user = MagicMock() - user.id = "acc-owner" + user = make_account("acc-owner") tenant_id = "tenant-1" external = "https://cdn.example/avatar.png" with ( app.test_request_context(f"/account/avatar?avatar={external}"), - patch( - "controllers.console.workspace.account.current_account_with_tenant", - return_value=(user, tenant_id), - ), patch( "controllers.console.workspace.account.file_helpers.get_signed_file_url", return_value="https://signed/should-not-use", ) as sign_mock, ): - result = method(api) + result = method(api, tenant_id, user) assert result == {"avatar_url": external} sign_mock.assert_not_called() @@ -266,7 +230,7 @@ class TestAccountAvatarApiGet: class TestAccountPasswordApi: def test_password_success(self, app: Flask): api = AccountPasswordApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "password": "old", @@ -274,63 +238,51 @@ class TestAccountPasswordApi: "repeat_new_password": "new123", } - user = MagicMock() - user.id = "u1" - user.name = "John" - user.email = "john@test.com" - user.avatar = "avatar.png" - user.interface_language = "en-US" - user.interface_theme = "light" - user.timezone = "UTC" - user.last_login_ip = "127.0.0.1" + user = make_account() with ( app.test_request_context("/", json=payload), - patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None), ): - result = method(api) + result = method(api, user) assert result["id"] == "u1" def test_password_wrong_current(self, app: Flask): api = AccountPasswordApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "password": "bad", "new_password": "new123", "repeat_new_password": "new123", } + user = make_account() with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch( "controllers.console.workspace.account.AccountService.update_account_password", side_effect=ServicePwdError(), ), ): with pytest.raises(CurrentPasswordIncorrectError): - method(api) + method(api, user) class TestAccountIntegrateApi: def test_get_integrates(self, app: Flask): api = AccountIntegrateApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - account = MagicMock(id="acc1") + account = make_account("acc1") with ( app.test_request_context("/"), - patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock, ): scalars_mock.return_value.all.return_value = [] - result = method(api) + result = method(api, account) assert "data" in result assert len(result["data"]) == 2 @@ -339,13 +291,11 @@ class TestAccountIntegrateApi: class TestAccountDeleteApi: def test_delete_verify_success(self, app: Flask): api = AccountDeleteVerifyApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = make_account() with ( app.test_request_context("/"), - patch( - "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch( "controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code", return_value=("token", "1234"), @@ -355,43 +305,38 @@ class TestAccountDeleteApi: return_value=None, ), ): - result = method(api) + result = method(api, user) assert result["result"] == "success" def test_delete_invalid_code(self, app: Flask): api = AccountDeleteApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"token": "t", "code": "x"} + user = make_account() with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch( "controllers.console.workspace.account.AccountService.verify_account_deletion_code", return_value=False, ), ): with pytest.raises(InvalidAccountDeletionCodeError): - method(api) + method(api, user) class TestChangeEmailApis: def test_check_email_code_invalid(self, app: Flask): api = ChangeEmailCheckApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"email": "a@test.com", "code": "x", "token": "t"} + user = make_account("acc-1") with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.workspace.account.current_account_with_tenant", - return_value=(MagicMock(id="acc-1"), "t1"), - ), patch.object( type(console_ns), "payload", @@ -412,13 +357,14 @@ class TestChangeEmailApis: ), ): with pytest.raises(EmailCodeError): - method(api) + method(api, user) def test_reset_email_already_used(self, app: Flask): api = ChangeEmailResetApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"new_email": "x@test.com", "token": "t"} + user = make_account() with ( app.test_request_context("/", json=payload), @@ -432,13 +378,13 @@ class TestChangeEmailApis: patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False), ): with pytest.raises(EmailAlreadyInUseError): - method(api) + method(api, user) class TestCheckEmailUniqueApi: def test_email_unique_success(self, app: Flask): api = CheckEmailUnique() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"email": "ok@test.com"} @@ -459,7 +405,7 @@ class TestCheckEmailUniqueApi: def test_email_in_freeze(self, app: Flask): api = CheckEmailUnique() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"email": "x@test.com"} diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py index ed7b2d606f..abd9b4facb 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -1,4 +1,5 @@ -from unittest.mock import MagicMock, patch +import inspect +from unittest.mock import patch import pytest from flask import Flask @@ -18,31 +19,10 @@ from controllers.console.workspace.endpoint import ( from core.plugin.impl.exc import PluginPermissionDeniedError -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - -@pytest.fixture -def user_and_tenant(): - return MagicMock(id="u1"), "t1" - - -@pytest.fixture -def patch_current_account(user_and_tenant): - with patch( - "controllers.console.workspace.endpoint.current_account_with_tenant", - return_value=user_and_tenant, - ): - yield - - -@pytest.mark.usefixtures("patch_current_account") class TestEndpointCollectionApi: def test_create_success(self, app: Flask): api = EndpointCollectionApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "plugin_unique_identifier": "plugin-1", @@ -54,13 +34,13 @@ class TestEndpointCollectionApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is True def test_create_permission_denied(self, app: Flask): api = EndpointCollectionApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "plugin_unique_identifier": "plugin-1", @@ -76,11 +56,11 @@ class TestEndpointCollectionApi: ), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") def test_create_validation_error(self, app: Flask): api = EndpointCollectionApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "plugin_unique_identifier": "p1", @@ -92,14 +72,13 @@ class TestEndpointCollectionApi: app.test_request_context("/", json=payload), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") -@pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointCreateApi: def test_create_success(self, app: Flask): api = DeprecatedEndpointCreateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "plugin_unique_identifier": "plugin-1", @@ -111,42 +90,40 @@ class TestDeprecatedEndpointCreateApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is True -@pytest.mark.usefixtures("patch_current_account") class TestEndpointListApi: def test_list_success(self, app: Flask): api = EndpointListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?page=1&page_size=10"), patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]), ): - result = method(api) + result = method(api, "t1", "u1") assert "endpoints" in result assert len(result["endpoints"]) == 1 def test_list_invalid_query(self, app: Flask): api = EndpointListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?page=0&page_size=10"), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") -@pytest.mark.usefixtures("patch_current_account") class TestEndpointListForSinglePluginApi: def test_list_for_plugin_success(self, app: Flask): api = EndpointListForSinglePluginApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?page=1&page_size=10&plugin_id=p1"), @@ -155,26 +132,25 @@ class TestEndpointListForSinglePluginApi: return_value=[{"id": "e1"}], ), ): - result = method(api) + result = method(api, "t1", "u1") assert "endpoints" in result def test_list_for_plugin_missing_param(self, app: Flask): api = EndpointListForSinglePluginApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) with ( app.test_request_context("/?page=1&page_size=10"), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") -@pytest.mark.usefixtures("patch_current_account") class TestEndpointItemApi: def test_delete_success(self, app: Flask): api = EndpointItemApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with ( app.test_request_context("/", method="DELETE"), @@ -183,26 +159,26 @@ class TestEndpointItemApi: return_value=True, ) as mock_delete, ): - result = method(api, "e1") + result = method(api, "t1", "u1", "e1") assert result["success"] is True mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1") def test_delete_service_failure(self, app: Flask): api = EndpointItemApi() - method = unwrap(api.delete) + method = inspect.unwrap(api.delete) with ( app.test_request_context("/", method="DELETE"), patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False), ): - result = method(api, "e1") + result = method(api, "t1", "u1", "e1") assert result["success"] is False def test_update_success(self, app: Flask): api = EndpointItemApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) payload = { "name": "new-name", @@ -216,7 +192,7 @@ class TestEndpointItemApi: return_value=True, ) as mock_update, ): - result = method(api, "e1") + result = method(api, "t1", "u1", "e1") assert result["success"] is True mock_update.assert_called_once_with( @@ -229,7 +205,7 @@ class TestEndpointItemApi: def test_update_validation_error(self, app: Flask): api = EndpointItemApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) payload = {"settings": {}} @@ -237,11 +213,11 @@ class TestEndpointItemApi: app.test_request_context("/", method="PATCH", json=payload), ): with pytest.raises(ValueError): - method(api, "e1") + method(api, "t1", "u1", "e1") def test_update_service_failure(self, app: Flask): api = EndpointItemApi() - method = unwrap(api.patch) + method = inspect.unwrap(api.patch) payload = { "name": "n", @@ -252,16 +228,15 @@ class TestEndpointItemApi: app.test_request_context("/", method="PATCH", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False), ): - result = method(api, "e1") + result = method(api, "t1", "u1", "e1") assert result["success"] is False -@pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointDeleteApi: def test_delete_success(self, app: Flask): api = DeprecatedEndpointDeleteApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -269,23 +244,23 @@ class TestDeprecatedEndpointDeleteApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is True def test_delete_invalid_payload(self, app: Flask): api = DeprecatedEndpointDeleteApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) with ( app.test_request_context("/", json={}), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") def test_delete_service_failure(self, app: Flask): api = DeprecatedEndpointDeleteApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -293,16 +268,15 @@ class TestDeprecatedEndpointDeleteApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is False -@pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointUpdateApi: def test_update_success(self, app: Flask): api = DeprecatedEndpointUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "endpoint_id": "e1", @@ -314,13 +288,13 @@ class TestDeprecatedEndpointUpdateApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is True def test_update_validation_error(self, app: Flask): api = DeprecatedEndpointUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"endpoint_id": "e1", "settings": {}} @@ -328,11 +302,11 @@ class TestDeprecatedEndpointUpdateApi: app.test_request_context("/", json=payload), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") def test_update_service_failure(self, app: Flask): api = DeprecatedEndpointUpdateApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = { "endpoint_id": "e1", @@ -344,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is False @@ -379,11 +353,10 @@ class TestEndpointRouteMetadata: assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",) -@pytest.mark.usefixtures("patch_current_account") class TestEndpointEnableApi: def test_enable_success(self, app: Flask): api = EndpointEnableApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -391,23 +364,23 @@ class TestEndpointEnableApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is True def test_enable_invalid_payload(self, app: Flask): api = EndpointEnableApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) with ( app.test_request_context("/", json={}), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") def test_enable_service_failure(self, app: Flask): api = EndpointEnableApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -415,16 +388,15 @@ class TestEndpointEnableApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is False -@pytest.mark.usefixtures("patch_current_account") class TestEndpointDisableApi: def test_disable_success(self, app: Flask): api = EndpointDisableApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"endpoint_id": "e1"} @@ -432,16 +404,16 @@ class TestEndpointDisableApi: app.test_request_context("/", json=payload), patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True), ): - result = method(api) + result = method(api, "t1", "u1") assert result["success"] is True def test_disable_invalid_payload(self, app: Flask): api = EndpointDisableApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) with ( app.test_request_context("/", json={}), ): with pytest.raises(ValueError): - method(api) + method(api, "t1", "u1") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index a294e8e893..68d5a879e4 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -1,3 +1,4 @@ +import inspect from io import BytesIO from unittest.mock import MagicMock, patch @@ -28,38 +29,47 @@ from controllers.console.workspace.workspace import ( ) from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now -from models.account import TenantStatus +from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +def make_account(account_id: str = "u1") -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + return account + + +def make_tenant( + tenant_id: str = "t1", + *, + name: str | None = None, + status: TenantStatus = TenantStatus.NORMAL, + custom_config: TenantCustomConfigDict | None = None, +) -> Tenant: + tenant = Tenant(name=name or f"Tenant {tenant_id}", status=status) + tenant.id = tenant_id + tenant.created_at = naive_utc_now() + if custom_config is not None: + tenant.custom_config_dict = custom_config + return tenant + + +def make_account_with_tenant(tenant: Tenant) -> Account: + account = make_account() + account._current_tenant = tenant + return account class TestTenantListApi: def test_get_success_saas_path(self, app: Flask): api = TenantListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant1 = MagicMock( - id="t1", - name="Tenant 1", - status="active", - created_at=naive_utc_now(), - ) - tenant2 = MagicMock( - id="t2", - name="Tenant 2", - status="active", - created_at=naive_utc_now(), - ) + tenant1 = make_tenant("t1", name="Tenant 1") + tenant2 = make_tenant("t2", name="Tenant 2") + user = make_account() with ( app.test_request_context("/workspaces"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant1, tenant2], @@ -76,7 +86,7 @@ class TestTenantListApi: ) as get_plan_bulk_mock, patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, ): - result, status = method(api) + result, status = method(api, "t1", user) assert status == 200 assert len(result["workspaces"]) == 2 @@ -93,30 +103,18 @@ class TestTenantListApi: (SaaS contract treats enabled as on; display follows subscription.plan). """ api = TenantListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant1 = MagicMock( - id="t1", - name="Tenant 1", - status="active", - created_at=naive_utc_now(), - ) - tenant2 = MagicMock( - id="t2", - name="Tenant 2", - status="active", - created_at=naive_utc_now(), - ) + tenant1 = make_tenant("t1", name="Tenant 1") + tenant2 = make_tenant("t2", name="Tenant 2") features_t2 = MagicMock() features_t2.billing.enabled = False features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL + user = make_account() with ( app.test_request_context("/workspaces"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant1, tenant2], @@ -133,7 +131,7 @@ class TestTenantListApi: return_value=features_t2, ) as get_features_mock, ): - result, status = method(api) + result, status = method(api, "t1", user) assert status == 200 assert result["workspaces"][0]["plan"] == CloudPlan.TEAM @@ -148,30 +146,18 @@ class TestTenantListApi: so we simulate the real failure mode by returning empty dict for non-empty input. """ api = TenantListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant1 = MagicMock( - id="t1", - name="Tenant 1", - status="active", - created_at=naive_utc_now(), - ) - tenant2 = MagicMock( - id="t2", - name="Tenant 2", - status="active", - created_at=naive_utc_now(), - ) + tenant1 = make_tenant("t1", name="Tenant 1") + tenant2 = make_tenant("t2", name="Tenant 2") features = MagicMock() features.billing.enabled = False features.billing.subscription.plan = CloudPlan.TEAM + user = make_account() with ( app.test_request_context("/workspaces"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") - ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant1, tenant2], @@ -189,7 +175,7 @@ class TestTenantListApi: ) as get_features_mock, patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock, ): - result, status = method(api) + result, status = method(api, "t2", user) assert status == 200 assert result["workspaces"][0]["plan"] == CloudPlan.TEAM @@ -200,25 +186,17 @@ class TestTenantListApi: def test_get_billing_disabled_community_path(self, app: Flask): api = TenantListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant = MagicMock( - id="t1", - name="Tenant", - status="active", - created_at=naive_utc_now(), - ) + tenant = make_tenant("t1", name="Tenant") features = MagicMock() features.billing.enabled = False features.billing.subscription.plan = CloudPlan.SANDBOX + user = make_account() with ( app.test_request_context("/workspaces"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant], @@ -231,7 +209,7 @@ class TestTenantListApi: return_value=features, ) as get_features_mock, ): - result, status = method(api) + result, status = method(api, "t1", user) assert status == 200 assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX @@ -239,26 +217,14 @@ class TestTenantListApi: def test_get_enterprise_only_skips_feature_service(self, app: Flask): api = TenantListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant1 = MagicMock( - id="t1", - name="Tenant 1", - status="active", - created_at=naive_utc_now(), - ) - tenant2 = MagicMock( - id="t2", - name="Tenant 2", - status="active", - created_at=naive_utc_now(), - ) + tenant1 = make_tenant("t1", name="Tenant 1") + tenant2 = make_tenant("t2", name="Tenant 2") + user = make_account() with ( app.test_request_context("/workspaces"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") - ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[tenant1, tenant2], @@ -268,7 +234,7 @@ class TestTenantListApi: patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, ): - result, status = method(api) + result, status = method(api, "t2", user) assert status == 200 assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX @@ -279,13 +245,11 @@ class TestTenantListApi: def test_get_enterprise_only_with_empty_tenants(self, app: Flask): api = TenantListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) + user = make_account() with ( app.test_request_context("/workspaces"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None) - ), patch( "controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[], @@ -295,7 +259,7 @@ class TestTenantListApi: patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, ): - result, status = method(api) + result, status = method(api, None, user) assert status == 200 assert result["workspaces"] == [] @@ -305,9 +269,9 @@ class TestTenantListApi: class TestWorkspaceListApi: def test_get_success(self, app: Flask): api = WorkspaceListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now()) + tenant = make_tenant("t1", name="T") paginate_result = MagicMock(items=[tenant], has_next=False, total=1) with ( @@ -322,9 +286,9 @@ class TestWorkspaceListApi: def test_get_has_next_true(self, app: Flask): api = WorkspaceListApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now()) + tenant = make_tenant("t1", name="T") paginate_result = MagicMock(items=[tenant], has_next=True, total=10) with ( @@ -340,80 +304,71 @@ class TestWorkspaceListApi: class TestTenantApi: def test_post_active_tenant(self, app: Flask): api = TenantApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - tenant = MagicMock(status="active") - - user = MagicMock(current_tenant=tenant) + tenant = make_tenant() + user = make_account_with_tenant(tenant) with ( app.test_request_context("/workspaces/current"), - patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} ), ): - result, status = method(api) + result, status = method(api, user) assert status == 200 assert result["id"] == "t1" def test_post_archived_with_switch(self, app: Flask): api = TenantApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - archived = MagicMock(status=TenantStatus.ARCHIVE) - new_tenant = MagicMock(status="active") - - user = MagicMock(current_tenant=archived) + archived = make_tenant(status=TenantStatus.ARCHIVE) + new_tenant = make_tenant("new") + user = make_account_with_tenant(archived) with ( app.test_request_context("/workspaces/current"), - patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"} ), ): - result, status = method(api) + result, status = method(api, user) assert result["id"] == "new" def test_post_archived_no_tenant(self, app: Flask): api = TenantApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE)) + user = make_account_with_tenant(make_tenant(status=TenantStatus.ARCHIVE)) with ( app.test_request_context("/workspaces/current"), - patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]), ): with pytest.raises(Unauthorized): - method(api) + method(api, user) def test_post_info_path(self, app: Flask): api = TenantApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - tenant = MagicMock(status="active") - user = MagicMock(current_tenant=tenant) + tenant = make_tenant() + user = make_account_with_tenant(tenant) with ( app.test_request_context("/info"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(user, "t1"), - ), patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}, ), patch("controllers.console.workspace.workspace.logger.warning") as warn_mock, ): - result, status = method(api) + result, status = method(api, user) warn_mock.assert_called_once() assert status == 200 @@ -456,16 +411,14 @@ class TestTenantInfoResponse: class TestSwitchWorkspaceApi: def test_switch_success(self, app: Flask): api = SwitchWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"tenant_id": "t2"} - tenant = MagicMock(id="t2") + tenant = make_tenant("t2") + user = make_account() with ( app.test_request_context("/workspaces/switch", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), patch("controllers.console.workspace.workspace.db.session.get") as get_mock, patch( @@ -473,85 +426,73 @@ class TestSwitchWorkspaceApi: ), ): get_mock.return_value = tenant - result = method(api) + result = method(api, user) assert result["result"] == "success" def test_switch_not_linked(self, app: Flask): api = SwitchWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"tenant_id": "bad"} + user = make_account() with ( app.test_request_context("/workspaces/switch", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception), ): with pytest.raises(AccountNotLinkTenantError): - method(api) + method(api, user) def test_switch_tenant_not_found(self, app: Flask): api = SwitchWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"tenant_id": "missing"} + user = make_account() with ( app.test_request_context("/workspaces/switch", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), patch("controllers.console.workspace.workspace.db.session.get") as get_mock, ): get_mock.return_value = None with pytest.raises(ValueError): - method(api) + method(api, user) class TestCustomConfigWorkspaceApi: def test_post_success(self, app: Flask): api = CustomConfigWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - tenant = MagicMock(custom_config_dict={}) + tenant = make_tenant(custom_config={}) payload = {"remove_webapp_brand": True} with ( app.test_request_context("/workspaces/custom-config", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), patch("controllers.console.workspace.workspace.db.session.commit"), patch( "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} ), ): - result = method(api) + result = method(api, "t1") assert result["result"] == "success" def test_logo_fallback(self, app: Flask): api = CustomConfigWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"}) + tenant = make_tenant(custom_config={"replace_webapp_logo": "old-logo"}) payload = {"remove_webapp_brand": False} with ( app.test_request_context("/workspaces/custom-config", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), patch( "controllers.console.workspace.workspace.db.get_or_404", return_value=tenant, @@ -562,7 +503,7 @@ class TestCustomConfigWorkspaceApi: return_value={"id": "t1"}, ), ): - result = method(api) + result = method(api, "t1") assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo" assert result["result"] == "success" @@ -571,54 +512,41 @@ class TestCustomConfigWorkspaceApi: class TestWebappLogoWorkspaceApi: def test_no_file(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) + user = make_account() - with ( - app.test_request_context("/upload", data={}), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), - ): + with app.test_request_context("/upload", data={}): with pytest.raises(NoFileUploadedError): - method(api) + method(api, user) def test_too_many_files(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) data = { "file": MagicMock(), "extra": MagicMock(), } + user = make_account() - with ( - app.test_request_context("/upload", data=data), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), - ): + with app.test_request_context("/upload", data=data): with pytest.raises(TooManyFilesError): - method(api) + method(api, user) def test_invalid_extension(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) file = MagicMock(filename="test.txt") + user = make_account() - with ( - app.test_request_context("/upload", data={"file": file}), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), - ): + with app.test_request_context("/upload", data={"file": file}): with pytest.raises(UnsupportedFileTypeError): - method(api) + method(api, user) def test_upload_success(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) file = FileStorage( stream=BytesIO(b"data"), @@ -627,6 +555,7 @@ class TestWebappLogoWorkspaceApi: ) upload = MagicMock(id="file1") + user = make_account() with ( app.test_request_context( @@ -634,53 +563,46 @@ class TestWebappLogoWorkspaceApi: data={"file": file}, content_type="multipart/form-data", ), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch("controllers.console.workspace.workspace.FileService") as fs, patch("controllers.console.workspace.workspace.db") as mock_db, ): mock_db.engine = MagicMock() fs.return_value.upload_file.return_value = upload - result, status = method(api) + result, status = method(api, user) assert status == 201 assert result["id"] == "file1" def test_filename_missing(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) file = FileStorage( stream=BytesIO(b"data"), filename="", content_type="image/png", ) + user = make_account() - with ( - app.test_request_context( - "/upload", - data={"file": file}, - content_type="multipart/form-data", - ), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), + with app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", ): with pytest.raises(FilenameNotExistsError): - method(api) + method(api, user) def test_file_too_large(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) file = FileStorage( stream=BytesIO(b"x"), filename="logo.png", content_type="image/png", ) + user = make_account() with ( app.test_request_context( @@ -688,10 +610,6 @@ class TestWebappLogoWorkspaceApi: data={"file": file}, content_type="multipart/form-data", ), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), patch("controllers.console.workspace.workspace.FileService") as fs, patch("controllers.console.workspace.workspace.db") as mock_db, ): @@ -699,17 +617,18 @@ class TestWebappLogoWorkspaceApi: fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big") with pytest.raises(FileTooLargeError): - method(api) + method(api, user) def test_service_unsupported_file(self, app: Flask): api = WebappLogoWorkspaceApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) file = FileStorage( stream=BytesIO(b"x"), filename="logo.png", content_type="image/png", ) + user = make_account() with ( app.test_request_context( @@ -717,10 +636,6 @@ class TestWebappLogoWorkspaceApi: data={"file": file}, content_type="multipart/form-data", ), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), "t1"), - ), patch("controllers.console.workspace.workspace.FileService") as fs, patch("controllers.console.workspace.workspace.db") as mock_db, ): @@ -728,23 +643,20 @@ class TestWebappLogoWorkspaceApi: fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError() with pytest.raises(UnsupportedFileTypeError): - method(api) + method(api, user) class TestWorkspaceInfoApi: def test_post_success(self, app: Flask): api = WorkspaceInfoApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) - tenant = MagicMock() + tenant = make_tenant() payload = {"name": "New Name"} with ( app.test_request_context("/workspaces/info", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), patch("controllers.console.workspace.workspace.db.session.commit"), patch( @@ -752,31 +664,27 @@ class TestWorkspaceInfoApi: return_value={"name": "New Name"}, ), ): - result = method(api) + result = method(api, "t1") assert result["result"] == "success" def test_no_current_tenant(self, app: Flask): api = WorkspaceInfoApi() - method = unwrap(api.post) + method = inspect.unwrap(api.post) payload = {"name": "X"} with ( app.test_request_context("/workspaces/info", json=payload), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), None), - ), ): with pytest.raises(ValueError): - method(api) + method(api, None) class TestWorkspacePermissionApi: def test_get_success(self, app: Flask): api = WorkspacePermissionApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) permission = MagicMock( workspace_id="t1", @@ -786,29 +694,20 @@ class TestWorkspacePermissionApi: with ( app.test_request_context("/permission"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") - ), patch( "controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission", return_value=permission, ), ): - result, status = method(api) + result, status = method(api, "t1") assert status == 200 assert result["workspace_id"] == "t1" def test_no_current_tenant(self, app: Flask): api = WorkspacePermissionApi() - method = unwrap(api.get) + method = inspect.unwrap(api.get) - with ( - app.test_request_context("/permission"), - patch( - "controllers.console.workspace.workspace.current_account_with_tenant", - return_value=(MagicMock(), None), - ), - ): + with app.test_request_context("/permission"): with pytest.raises(ValueError): - method(api) + method(api, None) diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py index ea60b94b61..10b80fb92f 100644 --- a/api/tests/unit_tests/services/controller_api.py +++ b/api/tests/unit_tests/services/controller_api.py @@ -82,11 +82,13 @@ This test suite follows a comprehensive testing strategy that covers: ================================================================================ """ +from collections.abc import Iterator +from types import SimpleNamespace from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from flask import Flask +from flask import Flask, g from flask.testing import FlaskClient from flask_restx import Api @@ -95,6 +97,7 @@ from controllers.console.datasets.external import ( ExternalApiTemplateListApi, ) from controllers.console.datasets.hit_testing import HitTestingApi +from models.account import Account, AccountStatus, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum # ============================================================================ @@ -817,15 +820,32 @@ class TestExternalDatasetApi: ) @pytest.fixture - def mock_current_user(self): - """Mock current user and tenant context.""" - with patch("controllers.console.datasets.external.current_account_with_tenant") as mock_get_user: - mock_user = ControllerApiTestDataFactory.create_user_mock(is_dataset_editor=True) + def mock_current_account_context(self, app: Flask) -> Iterator[Mock]: + """Provide the wrapper auth context required by HTTP-client controller tests.""" + mock_user = Account(name="Test User", email="user-123@example.com") + mock_user.id = "user-123" + mock_user.status = AccountStatus.ACTIVE + mock_user.role = TenantAccountRole.EDITOR + + def load_user_from_request_context() -> None: + g._login_user = mock_user + + setattr( # noqa: B010 + app, + "login_manager", + SimpleNamespace(load_user_from_request_context=load_user_from_request_context), + ) + + with ( + patch("controllers.console.wraps.current_account_with_tenant") as mock_get_user, + patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), + patch("libs.login.check_csrf_token", return_value=None), + ): mock_tenant_id = "tenant-123" mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_get_external_knowledge_apis_success(self, client_list, mock_current_user): + def test_get_external_knowledge_apis_success(self, client_list: FlaskClient, mock_current_account_context: Mock): """ Test successful retrieval of external knowledge API list. @@ -839,7 +859,11 @@ class TestExternalDatasetApi: - Status code is 200 """ # Arrange - apis = [{"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"} for i in range(3)] + apis = [] + for i in range(3): + api_item = Mock() + api_item.to_dict.return_value = {"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"} + apis.append(api_item) with patch( "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis" @@ -847,6 +871,7 @@ class TestExternalDatasetApi: mock_get_apis.return_value = (apis, 3) # Act + # TODO: this should be made integrated tests... response = client_list.get("/datasets/external-knowledge-api?page=1&limit=20") # Assert