refactor(api): migrate tenant/user via DI for several endpoints (#37026)

This commit is contained in:
chariri 2026-06-04 14:52:59 +09:00 committed by GitHub
parent 5b5a06136a
commit b67c3a5f76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 1186 additions and 1580 deletions

View File

@ -1,7 +1,7 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any, TypedDict
from typing import Any, Concatenate, TypedDict
from uuid import UUID
from flask import Response, request
@ -214,7 +214,9 @@ workflow_draft_variable_list_model = console_ns.model(
)
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite[T, **P, R](
f: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -231,8 +233,8 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@wraps(f)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(*args, **kwargs)
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
return f(self, *args, **kwargs)
return wrapper

View File

@ -24,14 +24,14 @@ from extensions.ext_database import db
from fields.base import ResponseModel
from libs.datetime_utils import naive_utc_now
from libs.helper import dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DataSourceOauthBinding, Document
from libs.login import login_required
from models import Account, DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
from services.datasource_provider_service import DatasourceProviderService
from tasks.document_indexing_sync_task import document_indexing_sync_task
from .. import console_ns
from ..wraps import account_initialization_required, setup_required
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
class NotionEstimatePayload(BaseModel):
@ -130,9 +130,8 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
def get(self) -> tuple[dict[str, Any], int]:
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
# get workspace data source integrates
data_source_integrates = db.session.scalars(
select(DataSourceOauthBinding).where(
@ -180,8 +179,10 @@ class DataSourceApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, binding_id: UUID, action: Literal["enable", "disable"]) -> tuple[dict[str, str], int]:
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def patch(
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
) -> tuple[dict[str, str], int]:
binding_id_str = str(binding_id)
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
data_source_binding = session.execute(
@ -220,9 +221,9 @@ class DataSourceNotionListApi(Resource):
@account_initialization_required
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
def get(self) -> tuple[dict[str, Any], int]:
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
@ -311,9 +312,8 @@ class DataSourceNotionPreviewApi(Resource):
@account_initialization_required
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
def get(self, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
datasource_provider_service = DatasourceProviderService()
@ -347,9 +347,8 @@ class DataSourceNotionIndexingEstimateApi(Resource):
@account_initialization_required
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
def post(self) -> tuple[dict[str, Any], int]:
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump()
# validate args

View File

@ -44,8 +44,8 @@ from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
from libs.datetime_utils import naive_utc_now
from libs.helper import dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from libs.login import login_required
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.enums import IndexingStatus, SegmentStatus
from services.dataset_service import DatasetService, DocumentService
@ -71,6 +71,8 @@ from ..wraps import (
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
logger = logging.getLogger(__name__)
@ -169,8 +171,9 @@ register_response_schema_models(
class DocumentResource(Resource):
def get_document(self, dataset_id: str, document_id: str) -> Document:
current_user, current_tenant_id = current_account_with_tenant()
def get_document(
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
) -> Document:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -190,8 +193,7 @@ class DocumentResource(Resource):
return document
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
current_user, _ = current_account_with_tenant()
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
@ -218,8 +220,8 @@ class GetProcessRuleApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
req_data = request.args
document_id = req_data.get("document_id")
@ -279,8 +281,9 @@ class DatasetDocumentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
@ -405,8 +408,8 @@ class DatasetDocumentListApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -480,9 +483,10 @@ class DatasetInitApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
current_user, current_tenant_id = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
@ -539,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
raise DocumentAlreadyFinishedError()
@ -604,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
if not documents:
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule
@ -704,9 +710,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, batch: str):
@with_current_user
def get(self, current_user: Account, dataset_id: UUID, batch: str):
dataset_id_str = str(dataset_id)
documents = self.get_batch_documents(dataset_id_str, batch)
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
documents_status = []
for document in documents:
completed_segments = (
@ -759,16 +766,18 @@ class DocumentIndexingStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
completed_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == document_id_str,
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -777,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource):
total_segments = (
db.session.scalar(
select(func.count(DocumentSegment.id)).where(
DocumentSegment.document_id == str(document_id_str),
DocumentSegment.document_id == document_id_str,
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
)
)
@ -820,10 +829,12 @@ class DocumentApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
metadata = request.args.get("metadata", "all")
if metadata not in self.METADATA_CHOICES:
@ -909,7 +920,9 @@ class DocumentApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(204, "Document deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
@with_current_user
@with_current_tenant_id
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -918,7 +931,7 @@ class DocumentApi(DocumentResource):
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
try:
DocumentService.delete_document(document)
@ -939,9 +952,11 @@ class DocumentDownloadApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
# Reuse the shared permission/tenant checks implemented in DocumentResource.
document = self.get_document(str(dataset_id), str(document_id))
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
return {"url": DocumentService.get_document_download_url(document)}
@ -956,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource):
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
def post(self, dataset_id: UUID):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
"""Stream a ZIP archive containing the requested uploaded documents."""
# Parse and validate request payload.
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
current_user, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
@ -1003,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
current_user, _ = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
action: Literal["pause", "resume"],
):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
if not current_user.is_dataset_editor:
@ -1051,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def put(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
document = self.get_document(dataset_id_str, document_id_str)
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
@ -1100,8 +1125,10 @@ class DocumentStatusApi(DocumentResource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
@ -1216,8 +1243,6 @@ class DocumentRetryApi(DocumentResource):
raise NotFound("Dataset not found.")
for document_id in payload.document_ids:
try:
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
# 404 if document not found
@ -1248,9 +1273,9 @@ class DocumentRenameApi(DocumentResource):
@account_initialization_required
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
@with_current_user
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
current_user, _ = current_account_with_tenant()
if not current_user.is_dataset_editor:
raise Forbidden()
dataset = DatasetService.get_dataset(dataset_id)
@ -1273,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
"""sync website document."""
_, current_tenant_id = current_account_with_tenant()
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if not dataset:
@ -1351,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id: UUID):
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
"""
Generate summary index for specified documents.
@ -1359,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource):
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
# Get dataset
@ -1444,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
@with_current_user
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
"""
Get summary index generation status for a document.
@ -1457,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource):
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)

View File

@ -33,6 +33,8 @@ from controllers.console.wraps import (
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
@ -51,7 +53,8 @@ from fields.segment_fields import (
)
from graphon.model_runtime.entities.model_entities import ModelType
from libs.helper import dump_response, escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
@ -164,9 +167,9 @@ class DatasetDocumentSegmentListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -274,9 +277,8 @@ class DatasetDocumentSegmentListApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
@console_ns.response(204, "Segments deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -312,9 +314,16 @@ class DatasetDocumentSegmentApi(Resource):
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
action: Literal["enable", "disable"],
):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if not dataset:
@ -373,9 +382,9 @@ class DatasetDocumentSegmentAddApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -431,9 +440,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -500,9 +511,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
@console_ns.response(204, "Segment deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def delete(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -548,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
@cloud_edition_billing_knowledge_limit_check("add_segment")
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
def post(self, dataset_id: UUID, document_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -619,9 +632,11 @@ class ChildChunkAddApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -677,9 +692,8 @@ class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -731,9 +745,11 @@ class ChildChunkAddApi(Resource):
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
)
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -781,9 +797,17 @@ class ChildChunkUpdateApi(Resource):
@cloud_edition_billing_rate_limit_check("knowledge")
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.response(204, "Child chunk deleted successfully")
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def delete(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
segment_id: UUID,
child_chunk_id: UUID,
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -840,9 +864,17 @@ class ChildChunkUpdateApi(Resource):
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(
self,
current_tenant_id: str,
current_user: Account,
dataset_id: UUID,
document_id: UUID,
segment_id: UUID,
child_chunk_id: UUID,
):
# check dataset
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)

View File

@ -15,6 +15,7 @@ from controllers.console.wraps import (
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from fields.dataset_fields import (
dataset_detail_fields,
@ -29,7 +30,8 @@ from fields.dataset_fields import (
vector_setting_fields,
weighted_score_fields,
)
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
@ -152,8 +154,9 @@ class ExternalApiTemplateListApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def post(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
ExternalDatasetService.validate_api_list(payload.settings)
@ -182,8 +185,8 @@ class ExternalApiTemplateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
external_knowledge_api_id_str, current_tenant_id
@ -197,8 +200,9 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
def patch(self, external_knowledge_api_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
@ -217,8 +221,9 @@ class ExternalApiTemplateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(204, "External knowledge API deleted successfully")
def delete(self, external_knowledge_api_id: UUID):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
@ -237,8 +242,8 @@ class ExternalApiUseCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, external_knowledge_api_id: UUID):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
external_knowledge_api_id_str = str(external_knowledge_api_id)
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
@ -259,9 +264,10 @@ class ExternalDatasetCreateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
@with_current_user
@with_current_tenant_id
def post(self, current_tenant_id: str, current_user: Account):
# The role of the current user in the ta table must be admin, owner, or editor
current_user, current_tenant_id = current_account_with_tenant()
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
@ -293,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, dataset_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:

View File

@ -9,11 +9,18 @@ from configs import dify_config
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from core.plugin.impl.oauth import OAuthHandler
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def get(self, provider_id: str):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
tenant_id = current_tenant_id
credential_id = request.args.get("credential_id")
datasource_provider_id = DatasourceProviderID(provider_id)
provider_name = datasource_provider_id.provider_name
@ -174,9 +180,8 @@ class DatasourceAuth(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -195,10 +200,11 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider_id: str):
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, user: Account, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
user, current_tenant_id = current_account_with_tenant()
datasources = datasource_provider_service.list_datasource_credentials(
tenant_id=current_tenant_id,
@ -217,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
plugin_id = datasource_provider_id.plugin_id
provider_name = datasource_provider_id.provider_name
@ -242,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
@ -265,9 +269,8 @@ class DatasourceAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -278,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
return {"result": jsonable_encoder(datasources)}, 200
@ -293,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -311,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def delete(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def delete(self, current_tenant_id: str, provider_id: str):
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_oauth_custom_client_params(
@ -331,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()
@ -353,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource):
@login_required
@account_initialization_required
@edit_permission_required
def post(self, provider_id: str):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str, provider_id: str):
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
datasource_provider_id = DatasourceProviderID(provider_id)
datasource_provider_service = DatasourceProviderService()

View File

@ -1,6 +1,6 @@
import logging
from collections.abc import Callable
from typing import Any, NoReturn
from typing import Any, Concatenate, NoReturn
from uuid import UUID
from flask import Response, request
@ -57,7 +57,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
def _api_prerequisite[T, **P, R](
f: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]:
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
@ -72,10 +74,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
@login_required
@account_initialization_required
@get_rag_pipeline
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)
return f(self, *args, **kwargs)
return wrapper

View File

@ -42,6 +42,8 @@ from controllers.console.wraps import (
enterprise_license_required,
only_edition_cloud,
setup_required,
with_current_tenant_id,
with_current_user,
)
from extensions.ext_database import db
from fields.base import ResponseModel
@ -49,8 +51,8 @@ from fields.member_fields import Account as AccountResponse
from graphon.file import helpers as file_helpers
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models import AccountIntegrate, InvitationCode
from libs.login import login_required
from models import Account, AccountIntegrate, InvitationCode
from models.account import AccountStatus, InvitationCodeStatus
from models.enums import CreatorUserRole
from models.model import UploadFile
@ -258,9 +260,8 @@ class AccountInitApi(Resource):
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
@setup_required
@login_required
def post(self):
account, _ = current_account_with_tenant()
@with_current_user
def post(self, account: Account):
if account.status == "active":
raise AccountAlreadyInitedError()
@ -306,8 +307,8 @@ class AccountProfileApi(Resource):
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
@enterprise_license_required
def get(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account):
return _serialize_account(current_user)
@ -318,8 +319,8 @@ class AccountNameApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountNamePayload.model_validate(payload)
updated_account = AccountService.update_account(current_user, name=args.name)
@ -336,8 +337,9 @@ class AccountAvatarApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
avatar = args.avatar
@ -362,8 +364,8 @@ class AccountAvatarApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountAvatarPayload.model_validate(payload)
@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountInterfaceLanguagePayload.model_validate(payload)
@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountInterfaceThemePayload.model_validate(payload)
@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountTimezonePayload.model_validate(payload)
@ -430,8 +432,8 @@ class AccountPasswordApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = AccountPasswordPayload.model_validate(payload)
@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
account_integrates = db.session.scalars(
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
).all()
@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
token, code = AccountService.generate_account_deletion_verification_code(account)
AccountService.send_account_deletion_verification_email(account, code)
@ -511,9 +511,8 @@ class AccountDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
account, _ = current_account_with_tenant()
@with_current_user
def post(self, account: Account):
payload = console_ns.payload or {}
args = AccountDeletePayload.model_validate(payload)
@ -547,9 +546,8 @@ class EducationVerifyApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
return EducationVerifyResponse.model_validate(
BillingService.EducationIdentity.verify(account.id, account.email) or {}
).model_dump(mode="json")
@ -563,9 +561,8 @@ class EducationApi(Resource):
@account_initialization_required
@only_edition_cloud
@cloud_edition_billing_enabled
def post(self):
account, _ = current_account_with_tenant()
@with_current_user
def post(self, account: Account):
payload = console_ns.payload or {}
args = EducationActivatePayload.model_validate(payload)
@ -577,9 +574,8 @@ class EducationApi(Resource):
@only_edition_cloud
@cloud_edition_billing_enabled
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
def get(self):
account, _ = current_account_with_tenant()
@with_current_user
def get(self, account: Account):
res = BillingService.EducationIdentity.status(account.id) or {}
# convert expire_at to UTC timestamp from isoformat
if res and "expire_at" in res:
@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = ChangeEmailSendPayload.model_validate(payload)
@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
current_user, _ = current_account_with_tenant()
reset_data = AccountService.get_change_email_data(args.token)
if not reset_data:
raise InvalidTokenError()

View File

@ -14,10 +14,16 @@ from pydantic import BaseModel, Field
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
is_admin_or_owner_required,
setup_required,
with_current_tenant_id,
with_current_user_id,
)
from core.plugin.impl.exc import PluginPermissionDeniedError
from graphon.model_runtime.utils.encoders import jsonable_encoder
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
@ -96,17 +102,15 @@ register_schema_models(
)
def _create_endpoint() -> dict[str, bool]:
"""Create a plugin endpoint for the current workspace."""
user, tenant_id = current_account_with_tenant()
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
"""Create a plugin endpoint for the injected workspace and user."""
args = EndpointCreatePayload.model_validate(console_ns.payload)
try:
return {
"success": EndpointService.create_endpoint(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
plugin_unique_identifier=args.plugin_unique_identifier,
name=args.name,
settings=args.settings,
@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]:
raise ValueError(e.description) from e
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
"""Update a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
args = EndpointUpdatePayload.model_validate(console_ns.payload)
return {
"success": EndpointService.update_endpoint(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
endpoint_id=endpoint_id,
name=args.name,
settings=args.settings,
@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
}
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
"""Delete a plugin endpoint identified by the canonical path parameter."""
user, tenant_id = current_account_with_tenant()
return {
"success": EndpointService.delete_endpoint(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
endpoint_id=endpoint_id,
)
}
@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
@console_ns.route("/workspaces/current/endpoints/create")
@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
return _create_endpoint()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
@console_ns.route("/workspaces/current/endpoints/list")
@ -206,9 +210,9 @@ class EndpointListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def get(self, tenant_id: str, user_id: str):
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -218,7 +222,7 @@ class EndpointListApi(Resource):
{
"endpoints": EndpointService.list_endpoints(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
page=page,
page_size=page_size,
)
@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def get(self, tenant_id: str, user_id: str):
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource):
{
"endpoints": EndpointService.list_endpoints_for_single_plugin(
tenant_id=tenant_id,
user_id=user.id,
user_id=user_id,
plugin_id=plugin_id,
page=page,
page_size=page_size,
@ -278,8 +282,10 @@ class EndpointItemApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def delete(self, id: str):
return _delete_endpoint(endpoint_id=id)
@with_current_user_id
@with_current_tenant_id
def delete(self, tenant_id: str, user_id: str, id: str):
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
@console_ns.doc("update_endpoint")
@console_ns.doc(description="Update a plugin endpoint")
@ -295,8 +301,10 @@ class EndpointItemApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def patch(self, id: str):
return _update_endpoint(endpoint_id=id)
@with_current_user_id
@with_current_tenant_id
def patch(self, tenant_id: str, user_id: str, id: str):
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
@console_ns.route("/workspaces/current/endpoints/delete")
@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = EndpointIdPayload.model_validate(console_ns.payload)
return _delete_endpoint(endpoint_id=args.endpoint_id)
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/update")
@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
return _update_endpoint(endpoint_id=args.endpoint_id)
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
@console_ns.route("/workspaces/current/endpoints/enable")
@ -370,14 +382,14 @@ class EndpointEnableApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.enable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
)
}
@ -397,13 +409,13 @@ class EndpointDisableApi(Resource):
@login_required
@is_admin_or_owner_required
@account_initialization_required
def post(self):
user, tenant_id = current_account_with_tenant()
@with_current_user_id
@with_current_tenant_id
def post(self, tenant_id: str, user_id: str):
args = EndpointIdPayload.model_validate(console_ns.payload)
return {
"success": EndpointService.disable_endpoint(
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
)
}

View File

@ -25,13 +25,15 @@ from controllers.console.wraps import (
cloud_edition_billing_resource_check,
only_edition_enterprise,
setup_required,
with_current_tenant_id,
with_current_user,
)
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import TimestampField, dump_response, to_timestamp
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
from libs.login import login_required
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService, SubscriptionPlan
from services.enterprise.enterprise_service import EnterpriseService
@ -153,8 +155,9 @@ class TenantListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account):
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
@ -228,11 +231,11 @@ class TenantApi(Resource):
@login_required
@account_initialization_required
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
def post(self):
@with_current_user
def post(self, current_user: Account):
if request.path == "/info":
logger.warning("Deprecated URL /info was used.")
current_user, _ = current_account_with_tenant()
tenant = current_user.current_tenant
if not tenant:
raise ValueError("No current tenant")
@ -256,8 +259,8 @@ class SwitchWorkspaceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
payload = console_ns.payload or {}
args = SwitchWorkspacePayload.model_validate(payload)
@ -281,8 +284,8 @@ class CustomConfigWorkspaceApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = console_ns.payload or {}
args = WorkspaceCustomConfigPayload.model_validate(payload)
tenant = db.get_or_404(Tenant, current_tenant_id)
@ -308,8 +311,8 @@ class WebappLogoWorkspaceApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("workspace_custom")
def post(self):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account):
# check file
if "file" not in request.files:
raise NoFileUploadedError()
@ -349,8 +352,8 @@ class WorkspaceInfoApi(Resource):
@login_required
@account_initialization_required
# Change workspace name
def post(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def post(self, current_tenant_id: str):
payload = console_ns.payload or {}
args = WorkspaceInfoPayload.model_validate(payload)
@ -372,13 +375,12 @@ class WorkspacePermissionApi(Resource):
@login_required
@account_initialization_required
@only_edition_enterprise
def get(self):
@with_current_tenant_id
def get(self, current_tenant_id: str):
"""
Get workspace permission settings.
Returns permission flags that control workspace features like member invitations and owner transfer.
"""
_, current_tenant_id = current_account_with_tenant()
if not current_tenant_id:
raise ValueError("No current tenant")

View File

@ -4,7 +4,7 @@ import os
import time
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
from typing import Any, Concatenate, overload
from flask import abort, request
from pydantic import BaseModel, ValidationError
@ -37,9 +37,21 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@overload
def account_initialization_required[T, **P, R](
view: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]: ...
@overload
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
def account_initialization_required[R](view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: Any, **kwargs: Any) -> R:
# The overloads keep Resource methods method-aware for pyrefly while
# preserving support for plain functions used in tests and utilities.
# check account initialization
current_user, _ = current_account_with_tenant()
if current_user.status == AccountStatus.UNINITIALIZED:
@ -218,9 +230,21 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
return decorated
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@overload
def setup_required[T, **P, R](
view: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R]: ...
@overload
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
def setup_required[R](view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
def decorated(*args: Any, **kwargs: Any) -> R:
# The overloads keep Resource methods method-aware for pyrefly while
# preserving support for plain functions used in tests and utilities.
# check setup
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
if os.environ.get("INIT_PASSWORD"):
@ -552,7 +576,7 @@ def with_current_user_id[T, **P, R](
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, str(current_user.id), *args, **kwargs)
return view(self, current_user.id, *args, **kwargs)
return decorated

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Concatenate, cast, overload
from flask import Response, current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS
@ -48,7 +48,17 @@ def current_account_with_tenant() -> tuple[Account, str]:
return user, user.current_tenant_id
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
@overload
def login_required[T, **P, R](
func: Callable[Concatenate[T, P], R],
) -> Callable[Concatenate[T, P], R | Response]: ...
@overload
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: ...
def login_required[R](func: Callable[..., R]) -> Callable[..., R | Response]:
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@ -83,7 +93,9 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
"""
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
def decorated_view(*args: Any, **kwargs: Any) -> R | Response:
# The overloads keep Resource methods method-aware for pyrefly while
# preserving support for plain Flask view functions.
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
return current_app.ensure_sync(func)(*args, **kwargs)

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import inspect
from collections.abc import Iterator
from datetime import UTC, datetime
from unittest.mock import MagicMock, PropertyMock, patch
@ -19,31 +21,18 @@ from controllers.console.datasets.data_source import (
DataSourceNotionPreviewApi,
)
from core.rag.index_processor.constant.index_type import IndexStructureType
from models import DataSourceOauthBinding
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
from models import Account, DataSourceOauthBinding
@pytest.fixture
def tenant_ctx():
return (MagicMock(id="u1"), "tenant-1")
def current_user() -> Account:
account = Account(name="Test User", email="u1@example.com")
account.id = "u1"
return account
@pytest.fixture
def patch_tenant(tenant_ctx):
with patch(
"controllers.console.datasets.data_source.current_account_with_tenant",
return_value=tenant_ctx,
):
yield
@pytest.fixture
def mock_engine():
def mock_engine() -> Iterator[None]:
with patch.object(
type(data_source.db),
"engine",
@ -55,12 +44,12 @@ def mock_engine():
class TestDataSourceApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_success(self, app: Flask, patch_tenant):
def test_get_success(self, app: Flask) -> None:
api = DataSourceApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
binding = DataSourceOauthBinding(
tenant_id="tenant-1",
@ -93,7 +82,7 @@ class TestDataSourceApi:
return_value=MagicMock(all=lambda: [binding]),
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
assert response["data"][0] == {
@ -120,9 +109,9 @@ class TestDataSourceApi:
"link": "http://localhost/console/api/oauth/data-source/notion",
}
def test_get_no_bindings(self, app: Flask, patch_tenant):
def test_get_no_bindings(self, app: Flask) -> None:
api = DataSourceApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -131,14 +120,14 @@ class TestDataSourceApi:
return_value=MagicMock(all=lambda: []),
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
def test_patch_enable_binding(self, app: Flask, mock_engine: None) -> None:
api = DataSourceApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
@ -152,14 +141,14 @@ class TestDataSourceApi:
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "enable")
response, status = method(api, "tenant-1", "b1", "enable")
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
def test_patch_disable_binding(self, app: Flask, mock_engine: None) -> None:
api = DataSourceApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
@ -173,14 +162,14 @@ class TestDataSourceApi:
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
response, status = method(api, "b1", "disable")
response, status = method(api, "tenant-1", "b1", "disable")
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
def test_patch_binding_not_found(self, app: Flask, mock_engine: None) -> None:
api = DataSourceApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
with (
app.test_request_context("/"),
@ -191,11 +180,11 @@ class TestDataSourceApi:
mock_session.execute.return_value.scalar_one_or_none.return_value = None
with pytest.raises(NotFound):
method(api, "b1", "enable")
method(api, "tenant-1", "b1", "enable")
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
def test_patch_enable_already_enabled(self, app: Flask, mock_engine: None) -> None:
api = DataSourceApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
binding = MagicMock(id="b1", disabled=False)
@ -208,11 +197,11 @@ class TestDataSourceApi:
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "enable")
method(api, "tenant-1", "b1", "enable")
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
def test_patch_disable_already_disabled(self, app: Flask, mock_engine: None) -> None:
api = DataSourceApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
binding = MagicMock(id="b1", disabled=True)
@ -225,17 +214,17 @@ class TestDataSourceApi:
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
with pytest.raises(ValueError):
method(api, "b1", "disable")
method(api, "tenant-1", "b1", "disable")
class TestDataSourceNotionListApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_credential_not_found(self, app: Flask, patch_tenant):
def test_get_credential_not_found(self, app: Flask, current_user: Account) -> None:
api = DataSourceNotionListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?credential_id=c1"),
@ -245,11 +234,11 @@ class TestDataSourceNotionListApi:
),
):
with pytest.raises(NotFound):
method(api)
method(api, "tenant-1", current_user)
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
def test_get_success_no_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None:
api = DataSourceNotionListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
page = MagicMock(
page_id="p1",
@ -284,13 +273,13 @@ class TestDataSourceNotionListApi:
),
),
):
response, status = method(api)
response, status = method(api, "tenant-1", current_user)
assert status == 200
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None:
api = DataSourceNotionListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
page = MagicMock(
page_id="p1",
@ -337,13 +326,13 @@ class TestDataSourceNotionListApi:
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = [document]
response, status = method(api)
response, status = method(api, "tenant-1", current_user)
assert status == 200
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
def test_get_invalid_dataset_type(self, app: Flask, current_user: Account, mock_engine: None) -> None:
api = DataSourceNotionListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
dataset = MagicMock(data_source_type="other_type")
@ -360,17 +349,17 @@ class TestDataSourceNotionListApi:
patch("controllers.console.datasets.data_source.sessionmaker"),
):
with pytest.raises(ValueError):
method(api)
method(api, "tenant-1", current_user)
class TestDataSourceNotionPreviewApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_preview_success(self, app: Flask, patch_tenant):
def test_get_preview_success(self, app: Flask) -> None:
api = DataSourceNotionPreviewApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
@ -385,21 +374,22 @@ class TestDataSourceNotionPreviewApi:
return_value=extractor,
),
):
response, status = method(api, "p1", "page")
response, status = method(api, "tenant-1", "p1", "page")
assert status == 200
class TestDataSourceNotionIndexingEstimateApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
def test_post_indexing_estimate_success(self, app: Flask) -> None:
api = DataSourceNotionIndexingEstimateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
empty_rules: dict[str, object] = {}
payload: dict[str, object] = {
"notion_info_list": [
{
"workspace_id": "w1",
@ -407,7 +397,7 @@ class TestDataSourceNotionIndexingEstimateApi:
"pages": [{"page_id": "p1", "type": "page"}],
}
],
"process_rule": {"rules": {}},
"process_rule": {"rules": empty_rules},
"doc_form": IndexStructureType.PARAGRAPH_INDEX,
"doc_language": "English",
}
@ -422,19 +412,19 @@ class TestDataSourceNotionIndexingEstimateApi:
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
class TestDataSourceNotionDatasetSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_success(self, app: Flask, patch_tenant):
def test_get_success(self, app: Flask) -> None:
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -455,9 +445,9 @@ class TestDataSourceNotionDatasetSyncApi:
assert status == 200
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
def test_get_dataset_not_found(self, app: Flask) -> None:
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -472,12 +462,12 @@ class TestDataSourceNotionDatasetSyncApi:
class TestDataSourceNotionDocumentSyncApi:
@pytest.fixture
def app(self, flask_app_with_containers: Flask):
def app(self, flask_app_with_containers: Flask) -> Flask:
return flask_app_with_containers
def test_get_success(self, app: Flask, patch_tenant):
def test_get_success(self, app: Flask) -> None:
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -498,9 +488,9 @@ class TestDataSourceNotionDocumentSyncApi:
assert status == 200
def test_get_document_not_found(self, app: Flask, patch_tenant):
def test_get_document_not_found(self, app: Flask) -> None:
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),

View File

@ -1,3 +1,4 @@
import inspect
from unittest.mock import MagicMock, patch
import pytest
@ -23,25 +24,15 @@ from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/?credential_id=cred-1"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
@ -58,20 +49,17 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
response = method(api, "tenant-1", user, "notion")
assert response.status_code == 200
def test_get_no_oauth_config(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
@ -79,20 +67,16 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", user, "notion")
def test_get_without_credential_id_sets_cookie(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_oauth_client",
@ -109,7 +93,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
return_value={"url": "http://auth"},
),
):
response = method(api, "notion")
response = method(api, "tenant-1", user, "notion")
assert response.status_code == 200
assert "context_id" in response.headers.get("Set-Cookie")
@ -118,7 +102,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
@ -160,7 +144,7 @@ class TestDatasourceOAuthCallback:
def test_callback_missing_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
@ -168,7 +152,7 @@ class TestDatasourceOAuthCallback:
def test_callback_invalid_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?context_id=bad"),
@ -183,7 +167,7 @@ class TestDatasourceOAuthCallback:
def test_callback_oauth_config_not_found(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
context = {"user_id": "u", "tenant_id": "t"}
@ -205,7 +189,7 @@ class TestDatasourceOAuthCallback:
def test_callback_reauthorize_existing_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
@ -248,7 +232,7 @@ class TestDatasourceOAuthCallback:
def test_callback_context_id_from_cookie(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
oauth_response = MagicMock()
oauth_response.credentials = {"token": "abc"}
@ -292,40 +276,32 @@ class TestDatasourceOAuthCallback:
class TestDatasourceAuth:
def test_post_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credentials": {"key": "val"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 200
def test_post_invalid_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credentials": {"key": "bad"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"add_datasource_api_key_provider",
@ -333,63 +309,53 @@ class TestDatasourceAuth:
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", "notion")
def test_get_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", user, "notion")
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", "notion")
def test_get_empty_list(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(id="user-1")
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"list_datasource_credentials",
return_value=[],
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", user, "notion")
assert status == 200
assert response["result"] == []
@ -398,136 +364,112 @@ class TestDatasourceAuth:
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credential_id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 200
def test_delete_missing_credential_id(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", "notion")
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credential_id": "id", "credentials": {"k": "v"}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 201
def test_update_with_credentials_none(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credential_id": "id", "credentials": None}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
),
):
_, status = method(api, "notion")
_, status = method(api, "tenant-1", "notion")
assert status == 201
def test_update_with_empty_credentials_dict(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credential_id": "id", "credentials": {}}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_credentials",
return_value=None,
) as update_mock,
):
_, status = method(api, "notion")
_, status = method(api, "tenant-1", "notion")
update_mock.assert_called_once()
assert status == 201
@ -536,62 +478,50 @@ class TestDatasourceAuthUpdateApi:
class TestDatasourceAuthListApi:
def test_list_success(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
def test_auth_list_empty(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_all_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[],
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
assert response["result"] == []
@ -600,21 +530,17 @@ class TestDatasourceAuthListApi:
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"get_hard_code_datasource_credentials",
return_value=[{"id": "1"}],
),
):
response, status = method(api)
response, status = method(api, "tenant-1")
assert status == 200
@ -622,73 +548,61 @@ class TestDatasourceHardCodeAuthListApi:
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"client_params": {}, "enable_oauth_custom_client": True}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 200
def test_delete_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"remove_oauth_custom_client_params",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 200
def test_post_empty_payload(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
),
):
_, status = method(api, "notion")
_, status = method(api, "tenant-1", "notion")
assert status == 200
def test_post_disabled_flag(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"client_params": {"a": 1},
@ -698,17 +612,13 @@ class TestDatasourceAuthOauthCustomClient:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"setup_oauth_custom_client_params",
return_value=None,
) as setup_mock,
):
_, status = method(api, "notion")
_, status = method(api, "tenant-1", "notion")
setup_mock.assert_called_once()
assert status == 200
@ -717,72 +627,60 @@ class TestDatasourceAuthOauthCustomClient:
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"id": "cred-1"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"set_default_datasource_provider",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 200
def test_default_missing_id(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", "notion")
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"credential_id": "id", "name": "New Name"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch.object(
DatasourceProviderService,
"update_datasource_provider_name",
return_value=None,
),
):
response, status = method(api, "notion")
response, status = method(api, "tenant-1", "notion")
assert status == 200
def test_update_name_too_long(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"credential_id": "id",
@ -792,27 +690,19 @@ class TestDatasourceUpdateProviderNameApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", "notion")
def test_update_name_missing_credential_id(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"name": "Valid"}
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
):
with pytest.raises(ValueError):
method(api, "notion")
method(api, "tenant-1", "notion")

View File

@ -11,7 +11,7 @@ from flask import Flask
from controllers.console.datasets import data_source as module
from controllers.console.datasets.data_source import DataSourceApi, DataSourceNotionListApi
from models import DataSourceOauthBinding
from models import Account, DataSourceOauthBinding
ControllerMethod = Callable[..., tuple[dict[str, object], int]]
@ -28,13 +28,13 @@ def flask_app() -> Flask:
@pytest.fixture
def tenant_context() -> tuple[MagicMock, str]:
return MagicMock(id="user-1"), "tenant-1"
def current_user() -> Account:
account = Account(name="Test User", email="user-1@example.com")
account.id = "user-1"
return account
def test_get_data_source_integrates_serializes_orm_binding(
flask_app: Flask, tenant_context: tuple[MagicMock, str]
) -> None:
def test_get_data_source_integrates_serializes_orm_binding(flask_app: Flask) -> None:
binding = DataSourceOauthBinding(
tenant_id="tenant-1",
access_token="token",
@ -61,10 +61,9 @@ def test_get_data_source_integrates_serializes_orm_binding(
with (
flask_app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=tenant_context),
patch.object(module.db.session, "scalars", return_value=MagicMock(all=lambda: [binding])),
):
response, status = unwrap(DataSourceApi().get)(DataSourceApi())
response, status = unwrap(DataSourceApi().get)(DataSourceApi(), "tenant-1")
assert status == 200
assert response == {
@ -96,23 +95,18 @@ def test_get_data_source_integrates_serializes_orm_binding(
}
def test_get_data_source_integrates_preserves_empty_list_when_no_binding(
flask_app: Flask, tenant_context: tuple[MagicMock, str]
) -> None:
def test_get_data_source_integrates_preserves_empty_list_when_no_binding(flask_app: Flask) -> None:
with (
flask_app.test_request_context("/"),
patch.object(module, "current_account_with_tenant", return_value=tenant_context),
patch.object(module.db.session, "scalars", return_value=MagicMock(all=lambda: [])),
):
response, status = unwrap(DataSourceApi().get)(DataSourceApi())
response, status = unwrap(DataSourceApi().get)(DataSourceApi(), "tenant-1")
assert status == 200
assert response == {"data": []}
def test_notion_pre_import_pages_serializes_frontend_list_shape(
flask_app: Flask, tenant_context: tuple[MagicMock, str]
) -> None:
def test_notion_pre_import_pages_serializes_frontend_list_shape(flask_app: Flask, current_user: Account) -> None:
page = MagicMock(
page_id="page-1",
page_name="Page",
@ -137,7 +131,6 @@ def test_notion_pre_import_pages_serializes_frontend_list_shape(
with (
flask_app.test_request_context("/?credential_id=credential-1"),
patch.object(module, "current_account_with_tenant", return_value=tenant_context),
patch.object(
module.DatasourceProviderService,
"get_datasource_credentials",
@ -147,7 +140,7 @@ def test_notion_pre_import_pages_serializes_frontend_list_shape(
patch.object(module, "sessionmaker"),
patch("core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", return_value=runtime),
):
response, status = unwrap(DataSourceNotionListApi().get)(DataSourceNotionListApi())
response, status = unwrap(DataSourceNotionListApi().get)(DataSourceNotionListApi(), "tenant-1", current_user)
assert status == 200
assert response == {

View File

@ -1,3 +1,4 @@
import inspect
from unittest.mock import MagicMock, patch
import pytest
@ -39,12 +40,6 @@ from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_serializable_document(**overrides):
attrs = {
"id": "doc-1",
@ -125,11 +120,7 @@ def tenant_ctx():
@pytest.fixture
def patch_tenant(tenant_ctx):
with patch(
"controllers.console.datasets.datasets_document.current_account_with_tenant",
return_value=tenant_ctx,
):
yield
return tenant_ctx
@pytest.fixture
@ -173,16 +164,18 @@ def patch_permission():
class TestGetProcessRuleApi:
def test_get_default_success(self, app: Flask, patch_tenant):
api = GetProcessRuleApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, _ = patch_tenant
with app.test_request_context("/"):
response = method(api)
response = method(api, user)
assert "rules" in response
def test_get_with_document_dataset_not_found(self, app: Flask, patch_tenant):
api = GetProcessRuleApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, _ = patch_tenant
document = MagicMock(dataset_id="ds-1")
@ -198,13 +191,14 @@ class TestGetProcessRuleApi:
),
):
with pytest.raises(NotFound):
method(api)
method(api, user)
class TestDatasetDocumentListApi:
def test_get_with_fetch_true_counts_segments(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
api = DatasetDocumentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
doc = make_serializable_document()
pagination = MagicMock(items=[doc], total=1)
@ -224,7 +218,7 @@ class TestDatasetDocumentListApi:
return_value=None,
),
):
resp = method(api, "ds-1")
resp = method(api, tenant_id, user, "ds-1")
assert resp["data"][0]["id"] == "doc-1"
assert resp["data"][0]["completed_segments"] == 2
@ -234,7 +228,8 @@ class TestDatasetDocumentListApi:
self, app: Flask, patch_tenant, patch_dataset, patch_permission
):
api = DatasetDocumentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
pagination = MagicMock(items=[make_serializable_document()], total=1)
@ -253,13 +248,14 @@ class TestDatasetDocumentListApi:
return_value=None,
),
):
resp = method(api, "ds-1")
resp = method(api, tenant_id, user, "ds-1")
assert resp["total"] == 1
def test_get_success(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
api = DatasetDocumentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
pagination = MagicMock(items=[make_serializable_document()], total=1)
@ -274,7 +270,7 @@ class TestDatasetDocumentListApi:
return_value=None,
),
):
response = method(api, "ds-1")
response = method(api, tenant_id, user, "ds-1")
assert response["total"] == 1
assert response["data"][0]["id"] == "doc-1"
@ -283,7 +279,8 @@ class TestDatasetDocumentListApi:
def test_post_success(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
api = DatasetDocumentListApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, _ = patch_tenant
payload = {"indexing_technique": "economy"}
created_dataset = make_dataset()
@ -306,7 +303,7 @@ class TestDatasetDocumentListApi:
),
patch("models.dataset.db.session.scalar", return_value=0),
):
response = method(api, "ds-1")
response = method(api, user, "ds-1")
assert "documents" in response
assert response["dataset"]["id"] == "ds-1"
@ -318,28 +315,25 @@ class TestDatasetDocumentListApi:
def test_post_forbidden(self, app: Flask):
api = DatasetDocumentListApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/", json={}),
patch.object(type(console_ns), "payload", {}),
patch(
"controllers.console.datasets.datasets_document.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_document.DatasetService.get_dataset",
return_value=MagicMock(),
),
):
with pytest.raises(Forbidden):
method(api, "ds-1")
method(api, user, "ds-1")
def test_get_with_fetch_true_and_invalid_fetch(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
api = DatasetDocumentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
pagination = MagicMock(items=[make_serializable_document()], total=1)
@ -354,13 +348,14 @@ class TestDatasetDocumentListApi:
return_value=None,
),
):
response = method(api, "ds-1")
response = method(api, tenant_id, user, "ds-1")
assert response["total"] == 1
def test_get_sort_hit_count(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
api = DatasetDocumentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
pagination = MagicMock(items=[], total=0)
@ -375,7 +370,7 @@ class TestDatasetDocumentListApi:
return_value=None,
),
):
response = method(api, "ds-1")
response = method(api, tenant_id, user, "ds-1")
assert response["total"] == 0
@ -383,7 +378,8 @@ class TestDatasetDocumentListApi:
class TestDatasetInitApi:
def test_post_success_serializes_created_dataset_and_documents(self, app: Flask, patch_tenant):
api = DatasetInitApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, tenant_id = patch_tenant
payload = {"indexing_technique": "economy"}
created_dataset = make_dataset()
@ -402,7 +398,7 @@ class TestDatasetInitApi:
),
patch("models.dataset.db.session.scalar", return_value=0),
):
response = method(api)
response = method(api, tenant_id, user)
assert response["dataset"]["id"] == "ds-1"
assert response["documents"][0]["id"] == "doc-init"
@ -414,7 +410,8 @@ class TestDatasetInitApi:
class TestDocumentApi:
def test_get_success(self, app: Flask, patch_tenant):
api = DocumentApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(dataset_process_rule=None)
@ -426,21 +423,23 @@ class TestDocumentApi:
return_value={},
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 200
def test_get_invalid_metadata(self, app: Flask, patch_tenant):
api = DocumentApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
with app.test_request_context("/?metadata=wrong"), patch.object(api, "get_document", return_value=MagicMock()):
with pytest.raises(InvalidMetadataError):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
def test_delete_success(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
user, tenant_id = patch_tenant
with (
app.test_request_context("/"),
@ -454,13 +453,14 @@ class TestDocumentApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 204
def test_delete_indexing_error(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
user, tenant_id = patch_tenant
with (
app.test_request_context("/"),
@ -475,13 +475,14 @@ class TestDocumentApi:
),
):
with pytest.raises(DocumentIndexingError):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
class TestDocumentDownloadApi:
def test_download_success(self, app: Flask, patch_tenant):
api = DocumentDownloadApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock()
@ -493,7 +494,7 @@ class TestDocumentDownloadApi:
return_value="url",
),
):
response = method(api, "ds-1", "doc-1")
response = method(api, tenant_id, user, "ds-1", "doc-1")
assert response["url"] == "url"
@ -501,24 +502,21 @@ class TestDocumentDownloadApi:
class TestDocumentProcessingApi:
def test_processing_forbidden_when_not_editor(self, app: Flask):
api = DocumentProcessingApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user = MagicMock(is_dataset_editor=False)
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_document.current_account_with_tenant",
return_value=(user, "tenant"),
),
patch.object(api, "get_document", return_value=MagicMock()),
):
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1", "pause")
method(api, "tenant-1", user, "ds-1", "doc-1", "pause")
def test_resume_from_error_state(self, app: Flask, patch_tenant):
api = DocumentProcessingApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, tenant_id = patch_tenant
doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True)
@ -530,13 +528,14 @@ class TestDocumentProcessingApi:
return_value=None,
),
):
_, status = method(api, "ds-1", "doc-1", "resume")
_, status = method(api, tenant_id, user, "ds-1", "doc-1", "resume")
assert status == 200
def test_resume_success(self, app: Flask, patch_tenant):
api = DocumentProcessingApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, tenant_id = patch_tenant
document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True)
@ -548,13 +547,14 @@ class TestDocumentProcessingApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1", "resume")
response, status = method(api, tenant_id, user, "ds-1", "doc-1", "resume")
assert status == 200
def test_pause_success(self, app: Flask, patch_tenant):
api = DocumentProcessingApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, tenant_id = patch_tenant
document = MagicMock(indexing_status="indexing")
@ -566,25 +566,27 @@ class TestDocumentProcessingApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1", "pause")
response, status = method(api, tenant_id, user, "ds-1", "doc-1", "pause")
assert status == 200
def test_pause_invalid(self, app: Flask, patch_tenant):
api = DocumentProcessingApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, tenant_id = patch_tenant
document = MagicMock(indexing_status=IndexingStatus.COMPLETED)
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
with pytest.raises(InvalidActionError):
method(api, "ds-1", "doc-1", "pause")
method(api, tenant_id, user, "ds-1", "doc-1", "pause")
class TestDocumentMetadataApi:
def test_put_metadata_schema_filtering(self, app: Flask, patch_tenant):
api = DocumentMetadataApi()
method = unwrap(api.put)
method = inspect.unwrap(api.put)
user, tenant_id = patch_tenant
doc = MagicMock()
@ -607,13 +609,14 @@ class TestDocumentMetadataApi:
return_value=None,
),
):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
assert doc.doc_metadata == {"amount": 10}
def test_put_success(self, app: Flask, patch_tenant):
api = DocumentMetadataApi()
method = unwrap(api.put)
method = inspect.unwrap(api.put)
user, tenant_id = patch_tenant
document = MagicMock()
@ -631,21 +634,23 @@ class TestDocumentMetadataApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 200
def test_put_invalid_payload(self, app: Flask, patch_tenant):
api = DocumentMetadataApi()
method = unwrap(api.put)
method = inspect.unwrap(api.put)
user, tenant_id = patch_tenant
with app.test_request_context("/", json={}), patch.object(api, "get_document", return_value=MagicMock()):
with pytest.raises(ValueError):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
def test_put_invalid_doc_type(self, app: Flask, patch_tenant):
api = DocumentMetadataApi()
method = unwrap(api.put)
method = inspect.unwrap(api.put)
user, tenant_id = patch_tenant
payload = {"doc_type": "invalid", "doc_metadata": {}}
@ -658,13 +663,14 @@ class TestDocumentMetadataApi:
),
):
with pytest.raises(ValueError):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
class TestDocumentStatusApi:
def test_patch_success(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentStatusApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, _ = patch_tenant
with (
app.test_request_context("/?document_id=doc-1"),
@ -681,13 +687,14 @@ class TestDocumentStatusApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "enable")
response, status = method(api, user, "ds-1", "enable")
assert status == 200
def test_patch_invalid_action(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentStatusApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, _ = patch_tenant
with (
app.test_request_context("/?document_id=doc-1"),
@ -705,13 +712,13 @@ class TestDocumentStatusApi:
),
):
with pytest.raises(InvalidActionError):
method(api, "ds-1", "enable")
method(api, user, "ds-1", "enable")
class TestDocumentRetryApi:
def test_retry_archived_document_skipped(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentRetryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"document_ids": ["doc-1"]}
@ -739,7 +746,7 @@ class TestDocumentRetryApi:
def test_retry_success(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentRetryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"document_ids": ["doc-1"]}
@ -768,7 +775,7 @@ class TestDocumentRetryApi:
def test_retry_skips_completed_document(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentRetryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"document_ids": ["doc-1"]}
@ -795,7 +802,7 @@ class TestDocumentRetryApi:
class TestDocumentPipelineExecutionLogApi:
def test_get_log_success(self, app: Flask, patch_tenant, patch_dataset):
api = DocumentPipelineExecutionLogApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
log = MagicMock(
datasource_info="{}",
@ -823,7 +830,8 @@ class TestDocumentPipelineExecutionLogApi:
class TestDocumentGenerateSummaryApi:
def test_generate_summary_missing_documents(self, app: Flask, patch_tenant, patch_permission):
api = DocumentGenerateSummaryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, _ = patch_tenant
dataset = MagicMock(
indexing_technique="high_quality",
@ -845,11 +853,12 @@ class TestDocumentGenerateSummaryApi:
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
method(api, user, "ds-1")
def test_generate_not_enabled(self, app: Flask, patch_tenant, patch_permission):
api = DocumentGenerateSummaryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, _ = patch_tenant
dataset = MagicMock(indexing_technique="high_quality", summary_index_setting={"enable": False})
@ -864,11 +873,12 @@ class TestDocumentGenerateSummaryApi:
),
):
with pytest.raises(ValueError):
method(api, "ds-1")
method(api, user, "ds-1")
def test_generate_summary_success_with_qa_skip(self, app: Flask, patch_tenant, patch_permission):
api = DocumentGenerateSummaryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, _ = patch_tenant
dataset = MagicMock(
indexing_technique="high_quality",
@ -896,7 +906,7 @@ class TestDocumentGenerateSummaryApi:
return_value=None,
),
):
response, status = method(api, "ds-1")
response, status = method(api, user, "ds-1")
assert status == 200
@ -904,7 +914,8 @@ class TestDocumentGenerateSummaryApi:
class TestDocumentSummaryStatusApi:
def test_get_success(self, app: Flask, patch_tenant, patch_permission):
api = DocumentSummaryStatusApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, _ = patch_tenant
with (
app.test_request_context("/"),
@ -917,7 +928,7 @@ class TestDocumentSummaryStatusApi:
return_value={"total_segments": 0},
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, user, "ds-1", "doc-1")
assert status == 200
@ -925,7 +936,8 @@ class TestDocumentSummaryStatusApi:
class TestDocumentIndexingEstimateApi:
def test_indexing_estimate_file_not_found(self, app: Flask, patch_tenant):
api = DocumentIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(
indexing_status=IndexingStatus.INDEXING,
@ -945,11 +957,12 @@ class TestDocumentIndexingEstimateApi:
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
def test_indexing_estimate_generic_exception(self, app: Flask, patch_tenant):
api = DocumentIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(
indexing_status=IndexingStatus.INDEXING,
@ -982,36 +995,38 @@ class TestDocumentIndexingEstimateApi:
),
):
with pytest.raises(IndexingEstimateError):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
def test_get_finished(self, app: Flask, patch_tenant):
api = DocumentIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(indexing_status=IndexingStatus.COMPLETED)
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
with pytest.raises(DocumentAlreadyFinishedError):
method(api, "ds-1", "doc-1")
method(api, tenant_id, user, "ds-1", "doc-1")
class TestDocumentBatchDownloadZipApi:
def test_post_no_documents(self, app: Flask, patch_tenant):
api = DocumentBatchDownloadZipApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, tenant_id = patch_tenant
payload: dict[str, list[str]] = {"document_ids": []}
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
with pytest.raises(ValueError):
method(api, "ds-1")
method(api, tenant_id, user, "ds-1")
class TestDatasetDocumentListApiDelete:
def test_delete_success(self, app: Flask, patch_tenant, patch_dataset):
"""Test successful deletion of documents"""
api = DatasetDocumentListApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with (
app.test_request_context("/?document_id=doc-1&document_id=doc-2"),
@ -1031,7 +1046,7 @@ class TestDatasetDocumentListApiDelete:
def test_delete_indexing_error(self, app: Flask, patch_tenant, patch_dataset):
"""Test deletion with indexing error"""
api = DatasetDocumentListApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with (
app.test_request_context("/?document_id=doc-1"),
@ -1050,7 +1065,7 @@ class TestDatasetDocumentListApiDelete:
def test_delete_dataset_not_found(self, app: Flask, patch_tenant):
"""Test deletion when dataset not found"""
api = DatasetDocumentListApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with (
app.test_request_context("/?document_id=doc-1"),
@ -1066,7 +1081,8 @@ class TestDatasetDocumentListApiDelete:
class TestDocumentBatchIndexingEstimateApi:
def test_batch_indexing_estimate_website(self, app: Flask, patch_tenant):
api = DocumentBatchIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
doc = MagicMock(
indexing_status=IndexingStatus.INDEXING,
@ -1089,13 +1105,14 @@ class TestDocumentBatchIndexingEstimateApi:
return_value=MagicMock(model_dump=lambda: {"tokens": 2}),
),
):
resp, status = method(api, "ds-1", "batch-1")
resp, status = method(api, tenant_id, user, "ds-1", "batch-1")
assert status == 200
def test_batch_indexing_estimate_notion(self, app: Flask, patch_tenant):
api = DocumentBatchIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
doc = MagicMock(
indexing_status=IndexingStatus.INDEXING,
@ -1117,13 +1134,14 @@ class TestDocumentBatchIndexingEstimateApi:
return_value=MagicMock(model_dump=lambda: {"tokens": 1}),
),
):
resp, status = method(api, "ds-1", "batch-1")
resp, status = method(api, tenant_id, user, "ds-1", "batch-1")
assert status == 200
def test_batch_estimate_unsupported_datasource(self, app: Flask, patch_tenant):
api = DocumentBatchIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(
indexing_status=IndexingStatus.INDEXING,
@ -1134,22 +1152,24 @@ class TestDocumentBatchIndexingEstimateApi:
with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]):
with pytest.raises(ValueError):
method(api, "ds-1", "batch-1")
method(api, tenant_id, user, "ds-1", "batch-1")
def test_get_batch_estimate_invalid_batch(self, app: Flask, patch_tenant):
"""Test batch estimation with invalid batch"""
api = DocumentBatchIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()):
with pytest.raises(NotFound):
method(api, "ds-1", "invalid-batch")
method(api, tenant_id, user, "ds-1", "invalid-batch")
class TestDocumentBatchIndexingStatusApi:
def test_get_batch_status_success_serializes_status_shape(self, app: Flask, patch_tenant):
api = DocumentBatchIndexingStatusApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, _ = patch_tenant
document = MagicMock(
id="doc-1",
@ -1173,7 +1193,7 @@ class TestDocumentBatchIndexingStatusApi:
side_effect=[2, 3],
),
):
response = method(api, "ds-1", "batch-1")
response = method(api, user, "ds-1", "batch-1")
assert response == {
"data": [
@ -1197,17 +1217,19 @@ class TestDocumentBatchIndexingStatusApi:
def test_get_batch_status_invalid_batch(self, app: Flask, patch_tenant):
"""Test batch status with invalid batch"""
api = DocumentBatchIndexingStatusApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, _ = patch_tenant
with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()):
with pytest.raises(NotFound):
method(api, "ds-1", "invalid-batch")
method(api, user, "ds-1", "invalid-batch")
class TestDocumentIndexingStatusApi:
def test_get_status_success_serializes_status_shape(self, app: Flask, patch_tenant):
api = DocumentIndexingStatusApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(
id="doc-1",
@ -1231,7 +1253,7 @@ class TestDocumentIndexingStatusApi:
side_effect=[1, 4],
),
):
response = method(api, "ds-1", "doc-1")
response = method(api, tenant_id, user, "ds-1", "doc-1")
assert response["id"] == "doc-1"
assert response["indexing_status"] == "indexing"
@ -1241,17 +1263,19 @@ class TestDocumentIndexingStatusApi:
def test_get_status_document_not_found(self, app: Flask, patch_tenant):
"""Test getting status for non-existent document"""
api = DocumentIndexingStatusApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
with app.test_request_context("/"), patch.object(api, "get_document", side_effect=NotFound()):
with pytest.raises(NotFound):
method(api, "ds-1", "invalid-doc")
method(api, tenant_id, user, "ds-1", "invalid-doc")
class TestDocumentRenameApi:
def test_post_success_serializes_document_shape(self, app: Flask, patch_tenant):
api = DocumentRenameApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, _ = patch_tenant
payload = {"name": "Renamed Document"}
renamed_document = make_document(id="doc-renamed", name="Renamed Document")
@ -1273,7 +1297,7 @@ class TestDocumentRenameApi:
),
patch("models.dataset.db.session.scalar", return_value=0),
):
response = method(api, "ds-1", "doc-1")
response = method(api, user, "ds-1", "doc-1")
assert response["id"] == "doc-renamed"
assert response["name"] == "Renamed Document"
@ -1286,7 +1310,8 @@ class TestDocumentApiMetadata:
def test_get_with_only_option(self, app: Flask, patch_tenant):
"""Test get with 'only' metadata option"""
api = DocumentApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(dataset_process_rule=None, doc_metadata_details=[])
@ -1298,14 +1323,15 @@ class TestDocumentApiMetadata:
return_value={},
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 200
def test_get_with_without_option(self, app: Flask, patch_tenant):
"""Test get with 'without' metadata option"""
api = DocumentApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(dataset_process_rule=None)
@ -1317,7 +1343,7 @@ class TestDocumentApiMetadata:
return_value={},
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 200
@ -1326,7 +1352,8 @@ class TestDocumentGenerateSummaryApiSuccess:
def test_generate_not_enabled_high_quality(self, app: Flask, patch_tenant, patch_permission):
"""Test summary generation on non-high-quality dataset"""
api = DocumentGenerateSummaryApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user, _ = patch_tenant
dataset = MagicMock(indexing_technique="economy", summary_index_setting={"enable": True})
@ -1341,26 +1368,28 @@ class TestDocumentGenerateSummaryApiSuccess:
),
):
with pytest.raises(ValueError):
method(api, "ds-1")
method(api, user, "ds-1")
class TestDocumentProcessingApiResume:
def test_resume_invalid_status(self, app: Flask, patch_tenant):
"""Test resume on non-paused document"""
api = DocumentProcessingApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user, tenant_id = patch_tenant
document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False)
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
with pytest.raises(InvalidActionError):
method(api, "ds-1", "doc-1", "resume")
method(api, tenant_id, user, "ds-1", "doc-1", "resume")
class TestDocumentPermissionCases:
def test_document_batch_get_permission_denied(self, app: Flask, patch_tenant):
api = DocumentBatchIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
with (
app.test_request_context("/"),
@ -1374,11 +1403,12 @@ class TestDocumentPermissionCases:
),
):
with pytest.raises(Forbidden):
method(api, "ds-1", "batch-1")
method(api, tenant_id, user, "ds-1", "batch-1")
def test_document_batch_get_documents_not_found(self, app: Flask, patch_tenant):
api = DocumentBatchIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
with (
app.test_request_context("/"),
@ -1392,7 +1422,7 @@ class TestDocumentPermissionCases:
),
patch.object(api, "get_batch_documents", return_value=None),
):
response, status = method(api, "ds-1", "batch-1")
response, status = method(api, tenant_id, user, "ds-1", "batch-1")
assert status == 200
assert response == {
@ -1405,7 +1435,7 @@ class TestDocumentPermissionCases:
def test_document_tenant_mismatch(self, app: Flask):
api = DocumentApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(is_dataset_editor=True)
document = MagicMock(
@ -1415,13 +1445,9 @@ class TestDocumentPermissionCases:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_document.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_document.DatasetService.get_dataset",
return_value=MagicMock(), # ✅ prevents real DB call
return_value=MagicMock(),
),
patch(
"controllers.console.datasets.datasets_document.DocumentService.get_document",
@ -1433,11 +1459,12 @@ class TestDocumentPermissionCases:
),
):
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_process_rule_get_by_document_success(self, app: Flask, patch_tenant):
api = GetProcessRuleApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, _ = patch_tenant
document = MagicMock(dataset_id="ds-1")
process_rule = MagicMock(mode="custom", rules_dict={"a": 1})
@ -1461,7 +1488,7 @@ class TestDocumentPermissionCases:
return_value=process_rule,
),
):
result = method(api)
result = method(api, user)
if isinstance(result, tuple):
response, status = result
@ -1473,16 +1500,13 @@ class TestDocumentPermissionCases:
def test_process_rule_permission_denied(self, app: Flask):
api = GetProcessRuleApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock(is_dataset_editor=True)
document = MagicMock(dataset_id="ds-1")
with (
app.test_request_context("/?document_id=doc-1"),
patch(
"controllers.console.datasets.datasets_document.current_account_with_tenant",
return_value=(MagicMock(is_dataset_editor=True), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_document.db.get_or_404",
return_value=document,
@ -1497,14 +1521,15 @@ class TestDocumentPermissionCases:
),
):
with pytest.raises(Forbidden):
method(api)
method(api, user)
class TestDocumentListAdvancedCases:
def test_document_list_with_multiple_sort_options(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
"""Test document list with different sort options"""
api = DatasetDocumentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
pagination = MagicMock(items=[make_serializable_document()], total=1)
@ -1519,14 +1544,15 @@ class TestDocumentListAdvancedCases:
return_value=None,
),
):
response = method(api, "ds-1")
response = method(api, tenant_id, user, "ds-1")
assert response["total"] == 1
def test_document_metadata_with_schema_validation(self, app: Flask, patch_tenant):
"""Test document metadata update with schema validation"""
api = DocumentMetadataApi()
method = unwrap(api.put)
method = inspect.unwrap(api.put)
user, tenant_id = patch_tenant
doc = MagicMock()
payload = {
@ -1548,7 +1574,7 @@ class TestDocumentListAdvancedCases:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 200
assert doc.doc_metadata == {"amount": 5000, "currency": "USD"}
@ -1557,7 +1583,8 @@ class TestDocumentListAdvancedCases:
class TestDocumentIndexingEdgeCases:
def test_document_indexing_with_extraction_setting(self, app: Flask, patch_tenant):
api = DocumentIndexingEstimateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user, tenant_id = patch_tenant
document = MagicMock(
indexing_status=IndexingStatus.INDEXING,
@ -1586,6 +1613,6 @@ class TestDocumentIndexingEdgeCases:
return_value=MagicMock(model_dump=lambda: {"tokens": 5}),
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
assert status == 200

View File

@ -8,11 +8,11 @@ upload-file documents, and rejects unsupported or missing file cases.
from __future__ import annotations
import importlib
import inspect
import sys
from collections import UserDict
from io import BytesIO
from types import SimpleNamespace
from typing import Any
from zipfile import ZipFile
import pytest
@ -79,7 +79,7 @@ def _mock_document(
upload_file_id: str | None,
) -> SimpleNamespace:
"""Build a minimal document object used by the controller."""
data_source_info_dict: dict[str, Any] | None = None
data_source_info_dict: dict[str, object] | None = None
if upload_file_id is not None:
data_source_info_dict = {"upload_file_id": upload_file_id}
else:
@ -97,7 +97,6 @@ def _wire_common_success_mocks(
*,
module,
monkeypatch: pytest.MonkeyPatch,
current_tenant_id: str,
document_tenant_id: str,
data_source_type: str,
upload_file_id: str | None,
@ -107,9 +106,6 @@ def _wire_common_success_mocks(
"""Patch controller dependencies to create a deterministic test environment."""
import services.dataset_service as dataset_service_module
# Make `current_account_with_tenant()` return a known user + tenant id.
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (_mock_user(), current_tenant_id))
# Return a dataset object and allow permission checks to pass.
monkeypatch.setattr(module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1"))
monkeypatch.setattr(module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None)
@ -124,9 +120,9 @@ def _wire_common_success_mocks(
monkeypatch.setattr(module.DocumentService, "get_document", lambda *_args, **_kwargs: document)
# Mock UploadFile lookup via FileService batch helper.
upload_files_by_id: dict[str, Any] = {}
upload_files_by_id: dict[str, object] = {}
if upload_file_exists and upload_file_id is not None:
upload_files_by_id[str(upload_file_id)] = SimpleNamespace(id=str(upload_file_id))
upload_files_by_id[upload_file_id] = SimpleNamespace(id=upload_file_id)
monkeypatch.setattr(module.FileService, "get_upload_files_by_ids", lambda *_args, **_kwargs: upload_files_by_id)
# Mock signing helper so the returned URL is deterministic.
@ -153,8 +149,6 @@ def test_batch_download_zip_returns_send_file(
) -> None:
"""Ensure batch ZIP download returns a zip attachment via `send_file`."""
# Arrange common permission mocks.
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
monkeypatch.setattr(
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
)
@ -204,7 +198,8 @@ def test_batch_download_zip_returns_send_file(
json={"document_ids": ["11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222"]},
):
api = datasets_document_module.DocumentBatchDownloadZipApi()
result = api.post(dataset_id="ds-1")
method = inspect.unwrap(api.post)
result = method(api, "tenant-123", _mock_user(), dataset_id="ds-1")
# Assert: we returned via send_file with correct mime type and attachment.
assert result["_send_file_kwargs"]["mimetype"] == "application/zip"
@ -222,7 +217,6 @@ def test_batch_download_zip_response_is_openable_zip(
"""Ensure the real Flask `send_file` response body is a valid ZIP that can be opened."""
# Arrange: same controller mocks as the lightweight send_file test, but we keep the real `send_file`.
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
monkeypatch.setattr(
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
)
@ -270,7 +264,8 @@ def test_batch_download_zip_response_is_openable_zip(
json={"document_ids": ["33333333-3333-3333-3333-333333333333", "44444444-4444-4444-4444-444444444444"]},
):
api = datasets_document_module.DocumentBatchDownloadZipApi()
response = api.post(dataset_id="ds-1")
method = inspect.unwrap(api.post)
response = method(api, "tenant-123", _mock_user(), dataset_id="ds-1")
# Assert: response body is a valid ZIP and contains the expected entries.
response.direct_passthrough = False
@ -288,7 +283,6 @@ def test_batch_download_zip_rejects_non_upload_file_document(
) -> None:
"""Ensure batch ZIP download rejects non upload-file documents."""
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
monkeypatch.setattr(
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
)
@ -314,8 +308,9 @@ def test_batch_download_zip_rejects_non_upload_file_document(
json={"document_ids": ["55555555-5555-5555-5555-555555555555"]},
):
api = datasets_document_module.DocumentBatchDownloadZipApi()
method = inspect.unwrap(api.post)
with pytest.raises(NotFound):
api.post(dataset_id="ds-1")
method(api, "tenant-123", _mock_user(), dataset_id="ds-1")
def test_document_download_returns_url_for_upload_file_document(
@ -326,7 +321,6 @@ def test_document_download_returns_url_for_upload_file_document(
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-123",
@ -337,7 +331,8 @@ def test_document_download_returns_url_for_upload_file_document(
# Build a request context then call the resource method directly.
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
result = api.get(dataset_id="ds-1", document_id="doc-1")
method = inspect.unwrap(api.get)
result = method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
assert result == {"url": "https://example.com/signed"}
@ -350,7 +345,6 @@ def test_document_download_rejects_non_upload_file_document(
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="website_crawl",
upload_file_id="file-123",
@ -360,8 +354,9 @@ def test_document_download_rejects_non_upload_file_document(
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
method = inspect.unwrap(api.get)
with pytest.raises(NotFound):
api.get(dataset_id="ds-1", document_id="doc-1")
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
def test_document_download_rejects_missing_upload_file_id(
@ -372,7 +367,6 @@ def test_document_download_rejects_missing_upload_file_id(
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id=None,
@ -382,8 +376,9 @@ def test_document_download_rejects_missing_upload_file_id(
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
method = inspect.unwrap(api.get)
with pytest.raises(NotFound):
api.get(dataset_id="ds-1", document_id="doc-1")
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
def test_document_download_rejects_when_upload_file_record_missing(
@ -394,7 +389,6 @@ def test_document_download_rejects_when_upload_file_record_missing(
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-123",
data_source_type="upload_file",
upload_file_id="file-123",
@ -404,8 +398,9 @@ def test_document_download_rejects_when_upload_file_record_missing(
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
method = inspect.unwrap(api.get)
with pytest.raises(NotFound):
api.get(dataset_id="ds-1", document_id="doc-1")
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
def test_document_download_rejects_tenant_mismatch(
@ -416,7 +411,6 @@ def test_document_download_rejects_tenant_mismatch(
_wire_common_success_mocks(
module=datasets_document_module,
monkeypatch=monkeypatch,
current_tenant_id="tenant-123",
document_tenant_id="tenant-999",
data_source_type="upload_file",
upload_file_id="file-123",
@ -426,5 +420,6 @@ def test_document_download_rejects_tenant_mismatch(
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
api = datasets_document_module.DocumentDownloadApi()
method = inspect.unwrap(api.get)
with pytest.raises(Forbidden):
api.get(dataset_id="ds-1", document_id="doc-1")
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")

View File

@ -1,3 +1,4 @@
import inspect
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
@ -34,12 +35,6 @@ from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDelete
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _segment():
segment = DocumentSegment(
tenant_id="tenant-1",
@ -129,10 +124,11 @@ def test_segment_response_with_summary():
class TestDatasetDocumentSegmentListApi:
def test_get_success(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
dataset = MagicMock()
document = MagicMock()
user = MagicMock()
segment = _segment()
@ -143,10 +139,6 @@ class TestDatasetDocumentSegmentListApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -170,40 +162,34 @@ class TestDatasetDocumentSegmentListApi:
patch("models.dataset.db.session.scalar", return_value=None),
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
assert status == 200
def test_get_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_get_permission_denied(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
dataset = MagicMock()
user = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -214,13 +200,13 @@ class TestDatasetDocumentSegmentListApi:
),
):
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
class TestDatasetDocumentSegmentApi:
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user = MagicMock()
user.is_dataset_editor = True
@ -233,10 +219,6 @@ class TestDatasetDocumentSegmentApi:
with (
app.test_request_context("/?segment_id=s1&segment_id=s2"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -258,14 +240,14 @@ class TestDatasetDocumentSegmentApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1", "enable")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "enable")
assert status == 200
assert response["result"] == "success"
def test_patch_document_indexing_in_progress(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user = MagicMock()
user.is_dataset_editor = True
@ -278,10 +260,6 @@ class TestDatasetDocumentSegmentApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -304,11 +282,11 @@ class TestDatasetDocumentSegmentApi:
),
):
with pytest.raises(InvalidActionError):
method(api, "ds-1", "doc-1", "disable")
method(api, "tenant-1", user, "ds-1", "doc-1", "disable")
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user = MagicMock(is_dataset_editor=True)
@ -322,10 +300,6 @@ class TestDatasetDocumentSegmentApi:
with (
app.test_request_context("/?segment_id=s1"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -348,11 +322,11 @@ class TestDatasetDocumentSegmentApi:
),
):
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1", "enable")
method(api, "tenant-1", user, "ds-1", "doc-1", "enable")
def test_patch_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
user = MagicMock(is_dataset_editor=True)
@ -366,10 +340,6 @@ class TestDatasetDocumentSegmentApi:
with (
app.test_request_context("/?segment_id=s1"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -392,13 +362,13 @@ class TestDatasetDocumentSegmentApi:
),
):
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1", "enable")
method(api, "tenant-1", user, "ds-1", "doc-1", "enable")
class TestDatasetDocumentSegmentAddApi:
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"content": "hello"}
@ -416,10 +386,6 @@ class TestDatasetDocumentSegmentAddApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -447,14 +413,14 @@ class TestDatasetDocumentSegmentAddApi:
patch("models.dataset.db.session.scalar", return_value=None),
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
assert status == 200
assert response["data"]["id"] == "seg-1"
def test_post_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"content": "x"}
@ -471,10 +437,6 @@ class TestDatasetDocumentSegmentAddApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -489,11 +451,11 @@ class TestDatasetDocumentSegmentAddApi:
),
):
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"content": "x"}
@ -510,10 +472,6 @@ class TestDatasetDocumentSegmentAddApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -528,13 +486,13 @@ class TestDatasetDocumentSegmentAddApi:
),
):
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
class TestDatasetDocumentSegmentUpdateApi:
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
payload = {"content": "updated"}
@ -552,10 +510,6 @@ class TestDatasetDocumentSegmentUpdateApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -586,14 +540,14 @@ class TestDatasetDocumentSegmentUpdateApi:
),
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
):
response, status = method(api, "ds-1", "doc-1", "seg-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
assert status == 200
assert "data" in response
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
payload = {"content": "x"}
@ -610,10 +564,6 @@ class TestDatasetDocumentSegmentUpdateApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -632,26 +582,23 @@ class TestDatasetDocumentSegmentUpdateApi:
),
):
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1", "seg-1")
method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
class TestDatasetDocumentSegmentBatchImportApi:
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"upload_file_id": "file-1"}
upload_file = MagicMock(spec=UploadFile)
upload_file.name = "test.csv"
user = MagicMock(id="u1")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -673,45 +620,39 @@ class TestDatasetDocumentSegmentBatchImportApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
assert status == 200
assert response["job_status"] == "waiting"
def test_post_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"upload_file_id": "file-1"}
user = MagicMock(id="u1")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_post_document_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"upload_file_id": "file-1"}
user = MagicMock(id="u1")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -722,21 +663,18 @@ class TestDatasetDocumentSegmentBatchImportApi:
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_post_upload_file_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"upload_file_id": "file-1"}
user = MagicMock(id="u1")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -751,24 +689,21 @@ class TestDatasetDocumentSegmentBatchImportApi:
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_post_invalid_file_type(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"upload_file_id": "file-1"}
upload_file = MagicMock()
upload_file.name = "test.txt"
user = MagicMock(id="u1")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -783,24 +718,21 @@ class TestDatasetDocumentSegmentBatchImportApi:
),
):
with pytest.raises(ValueError):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_post_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"upload_file_id": "file-1"}
upload_file = MagicMock()
upload_file.name = "test.csv"
user = MagicMock(id="u1")
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -818,14 +750,14 @@ class TestDatasetDocumentSegmentBatchImportApi:
side_effect=Exception("redis down"),
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
assert status == 500
assert "error" in response
def test_get_job_not_found_in_redis(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -840,23 +772,19 @@ class TestDatasetDocumentSegmentBatchImportApi:
class TestChildChunkAddApi:
def test_patch_documents_batch_update_payload(self):
api_doc = unwrap(ChildChunkAddApi.patch).__apidoc__
api_doc = getattr(ChildChunkAddApi.patch, "__apidoc__") # noqa: B009
expected_model = ChildChunkBatchUpdatePayload.__name__
assert [model.name for model in api_doc["expect"]] == [expected_model]
def test_get_uses_default_pagination_for_malformed_ints(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
pagination = MagicMock(items=[], total=0, pages=0)
with (
app.test_request_context("/?page=bad&limit="),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -878,7 +806,7 @@ class TestChildChunkAddApi:
return_value=pagination,
) as get_child_chunks,
):
response, status = method(api, "ds-1", "doc-1", "seg-1")
response, status = method(api, "tenant-1", "ds-1", "doc-1", "seg-1")
assert status == 200
assert response["page"] == 1
@ -887,7 +815,7 @@ class TestChildChunkAddApi:
def test_post_success(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"content": "child"}
@ -904,10 +832,6 @@ class TestChildChunkAddApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -929,14 +853,14 @@ class TestChildChunkAddApi:
return_value=child_chunk,
),
):
response, status = method(api, "ds-1", "doc-1", "seg-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
assert status == 200
assert response["data"]["id"] == "cc-1"
def test_post_child_chunk_indexing_error(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"content": "child"}
@ -949,10 +873,6 @@ class TestChildChunkAddApi:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -975,13 +895,13 @@ class TestChildChunkAddApi:
),
):
with pytest.raises(ChildChunkIndexingError):
method(api, "ds-1", "doc-1", "seg-1")
method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
class TestChildChunkUpdateApi:
def test_delete_success(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
user = MagicMock()
user.is_dataset_editor = True
@ -993,10 +913,6 @@ class TestChildChunkUpdateApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1018,14 +934,14 @@ class TestChildChunkUpdateApi:
return_value=None,
),
):
response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1", "cc-1")
assert status == 204
assert response == ""
def test_delete_child_chunk_index_error(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
user = MagicMock(is_dataset_editor=True)
@ -1036,10 +952,6 @@ class TestChildChunkUpdateApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1062,16 +974,17 @@ class TestChildChunkUpdateApi:
),
):
with pytest.raises(ChildChunkDeleteIndexError):
method(api, "ds-1", "doc-1", "seg-1", "cc-1")
method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1", "cc-1")
class TestSegmentListAdvancedCases:
def test_segment_list_with_keyword_filter(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
dataset = MagicMock()
document = MagicMock()
user = MagicMock()
segment = _segment()
@ -1079,10 +992,6 @@ class TestSegmentListAdvancedCases:
with (
app.test_request_context("/?keyword=test"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1106,7 +1015,7 @@ class TestSegmentListAdvancedCases:
patch("models.dataset.db.session.scalar", return_value=None),
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
):
result = method(api, "ds-1", "doc-1")
result = method(api, "tenant-1", user, "ds-1", "doc-1")
if isinstance(result, tuple):
response, status = result
@ -1118,18 +1027,15 @@ class TestSegmentListAdvancedCases:
def test_segment_list_postgres_keyword_filter_handles_scalar_keywords(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
dataset = MagicMock()
document = MagicMock()
user = MagicMock()
pagination = MagicMock(items=[], total=0, pages=0)
with (
app.test_request_context("/?keyword=test"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "11111111-1111-1111-1111-111111111111"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1151,7 +1057,13 @@ class TestSegmentListAdvancedCases:
return_value=pagination,
) as paginate_mock,
):
method(api, "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333")
method(
api,
"11111111-1111-1111-1111-111111111111",
user,
"22222222-2222-2222-2222-222222222222",
"33333333-3333-3333-3333-333333333333",
)
query = paginate_mock.call_args.kwargs["select"]
sql = str(query.compile(compile_kwargs={"literal_binds": True}))
@ -1161,14 +1073,11 @@ class TestSegmentListAdvancedCases:
def test_segment_list_permission_denied(self, app: Flask):
"""Test segment list with permission denied"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=MagicMock(),
@ -1179,33 +1088,30 @@ class TestSegmentListAdvancedCases:
),
):
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_segment_list_dataset_not_found(self, app: Flask):
"""Test segment list with dataset not found"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
with (
app.test_request_context("/"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=None,
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
class TestSegmentOperationCases:
def test_segment_add_with_provider_token_error(self, app: Flask):
"""Test segment add with provider token not initialized"""
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock(is_dataset_editor=True)
dataset = MagicMock()
@ -1216,10 +1122,6 @@ class TestSegmentOperationCases:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1238,12 +1140,12 @@ class TestSegmentOperationCases:
),
):
with pytest.raises(ProviderTokenNotInitError):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_batch_import_with_document_not_found(self, app: Flask):
"""Test batch import with document not found"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock(is_dataset_editor=True)
dataset = MagicMock()
@ -1253,10 +1155,6 @@ class TestSegmentOperationCases:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1267,12 +1165,12 @@ class TestSegmentOperationCases:
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_batch_import_with_invalid_file(self, app: Flask):
"""Test batch import with invalid file type"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock(is_dataset_editor=True)
dataset = MagicMock()
@ -1284,10 +1182,6 @@ class TestSegmentOperationCases:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1302,11 +1196,11 @@ class TestSegmentOperationCases:
),
):
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
method(api, "tenant-1", user, "ds-1", "doc-1")
def test_batch_import_with_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock(is_dataset_editor=True)
dataset = MagicMock()
@ -1319,10 +1213,6 @@ class TestSegmentOperationCases:
with (
app.test_request_context("/", json=payload),
patch.object(type(console_ns), "payload", payload),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
return_value=dataset,
@ -1344,23 +1234,17 @@ class TestSegmentOperationCases:
side_effect=Exception("Task failed"),
),
):
response, status = method(api, "ds-1", "doc-1")
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
assert status == 500
assert "error" in response
def test_batch_import_get_job_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)
user = MagicMock(is_dataset_editor=True)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?job_id=invalid-job"),
patch(
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
return_value=(user, "tenant-1"),
),
patch(
"controllers.console.datasets.datasets_segments.redis_client.get",
return_value=None,

View File

@ -1,4 +1,4 @@
from importlib import import_module
import inspect
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -16,50 +16,32 @@ from controllers.console.datasets.external import (
ExternalDatasetCreateApi,
ExternalKnowledgeHitTestingApi,
)
from models.account import Account, TenantAccountRole
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
external_controller = import_module("controllers.console.datasets.external")
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def app():
def app() -> Flask:
app = Flask("test_external_dataset")
app.config["TESTING"] = True
return app
@pytest.fixture
def current_user():
user = MagicMock()
def current_user() -> Account:
user = Account(name="Test User", email="user-1@example.com")
user.id = "user-1"
user.is_dataset_editor = True
user.has_edit_permission = True
user.is_dataset_operator = True
user.role = TenantAccountRole.EDITOR
return user
@pytest.fixture(autouse=True)
def mock_auth(monkeypatch: pytest.MonkeyPatch, current_user):
monkeypatch.setattr(
external_controller,
"current_account_with_tenant",
lambda: (current_user, "tenant-1"),
)
class TestExternalApiTemplateListApi:
def test_get_success(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
api_item = MagicMock()
api_item.to_dict.return_value = {"id": "1"}
@ -79,10 +61,10 @@ class TestExternalApiTemplateListApi:
assert resp["data"][0]["id"] == "1"
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
def test_post_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
def test_post_forbidden(self, app: Flask, current_user: Account):
current_user.role = TenantAccountRole.NORMAL
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
@ -92,11 +74,11 @@ class TestExternalApiTemplateListApi:
patch.object(ExternalDatasetService, "validate_api_list"),
):
with pytest.raises(Forbidden):
method(api)
method(api, "tenant-1", current_user)
def test_post_duplicate_name(self, app: Flask):
def test_post_duplicate_name(self, app: Flask, current_user: Account):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"name": "x", "settings": {"k": "v"}}
@ -111,13 +93,13 @@ class TestExternalApiTemplateListApi:
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
method(api, "tenant-1", current_user)
class TestExternalApiTemplateApi:
def test_get_not_found(self, app: Flask):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -128,24 +110,23 @@ class TestExternalApiTemplateApi:
),
):
with pytest.raises(NotFound):
method(api, "api-id")
method(api, "tenant-1", "api-id")
def test_delete_forbidden(self, app: Flask, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
def test_delete_forbidden(self, app: Flask, current_user: Account):
current_user.role = TenantAccountRole.NORMAL
api = ExternalApiTemplateApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with app.test_request_context("/"):
with pytest.raises(Forbidden):
method(api, "api-id")
method(api, "tenant-1", current_user, "api-id")
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app: Flask):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/"),
@ -155,7 +136,7 @@ class TestExternalApiUseCheckApi:
return_value=(True, 2),
) as mock_use_check,
):
response, status = method(api, "api-id")
response, status = method(api, "tenant-1", "api-id")
assert status == 200
assert response == {"is_using": True, "count": 2}
@ -163,9 +144,9 @@ class TestExternalApiUseCheckApi:
class TestExternalDatasetCreateApi:
def test_create_success(self, app: Flask):
def test_create_success(self, app: Flask, current_user: Account):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
@ -203,14 +184,14 @@ class TestExternalDatasetCreateApi:
return_value=dataset,
),
):
_, status = method(api)
_, status = method(api, "tenant-1", current_user)
assert status == 201
def test_create_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
def test_create_forbidden(self, app: Flask, current_user: Account):
current_user.role = TenantAccountRole.NORMAL
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"external_knowledge_api_id": "api",
@ -223,13 +204,13 @@ class TestExternalDatasetCreateApi:
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
):
with pytest.raises(Forbidden):
method(api)
method(api, "tenant-1", current_user)
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app: Flask):
def test_hit_testing_dataset_not_found(self, app: Flask, current_user: Account):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
with (
app.test_request_context("/"),
@ -240,11 +221,11 @@ class TestExternalKnowledgeHitTestingApi:
),
):
with pytest.raises(NotFound):
method(api, "dataset-id")
method(api, current_user, "dataset-id")
def test_hit_testing_success(self, app: Flask):
def test_hit_testing_success(self, app: Flask, current_user: Account):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"query": "hello"}
@ -261,7 +242,7 @@ class TestExternalKnowledgeHitTestingApi:
return_value={"ok": True},
),
):
resp = method(api, "dataset-id")
resp = method(api, current_user, "dataset-id")
assert resp["ok"] is True
@ -269,7 +250,7 @@ class TestExternalKnowledgeHitTestingApi:
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app: Flask):
api = BedrockRetrievalApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"retrieval_setting": {},
@ -293,9 +274,9 @@ class TestBedrockRetrievalApi:
class TestExternalApiTemplateListApiAdvanced:
def test_post_duplicate_name_error(self, app: Flask, mock_auth, current_user):
def test_post_duplicate_name_error(self, app: Flask, current_user: Account):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
@ -309,11 +290,11 @@ class TestExternalApiTemplateListApiAdvanced:
),
):
with pytest.raises(DatasetNameDuplicateError):
method(api)
method(api, "tenant-1", current_user)
def test_get_with_pagination(self, app: Flask, mock_auth, current_user):
def test_get_with_pagination(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
@ -333,12 +314,12 @@ class TestExternalApiTemplateListApiAdvanced:
class TestExternalDatasetCreateApiAdvanced:
def test_create_forbidden(self, app: Flask, mock_auth, current_user):
def test_create_forbidden(self, app: Flask, current_user: Account):
"""Test creating external dataset without permission"""
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
current_user.is_dataset_editor = False
current_user.role = TenantAccountRole.NORMAL
payload = {
"external_knowledge_api_id": "api-1",
@ -349,14 +330,14 @@ class TestExternalDatasetCreateApiAdvanced:
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
with pytest.raises(Forbidden):
method(api)
method(api, "tenant-1", current_user)
class TestExternalKnowledgeHitTestingApiAdvanced:
def test_hit_testing_dataset_not_found(self, app: Flask, mock_auth, current_user):
def test_hit_testing_dataset_not_found(self, app: Flask, current_user: Account):
"""Test hit testing on non-existent dataset"""
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"query": "test query",
@ -372,11 +353,11 @@ class TestExternalKnowledgeHitTestingApiAdvanced:
),
):
with pytest.raises(NotFound):
method(api, "ds-1")
method(api, current_user, "ds-1")
def test_hit_testing_with_custom_retrieval_model(self, app: Flask, mock_auth, current_user):
def test_hit_testing_with_custom_retrieval_model(self, app: Flask, current_user: Account):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
dataset = MagicMock()
payload = {
@ -398,15 +379,15 @@ class TestExternalKnowledgeHitTestingApiAdvanced:
return_value={"results": []},
),
):
resp = method(api, "ds-1")
resp = method(api, current_user, "ds-1")
assert resp["results"] == []
class TestBedrockRetrievalApiAdvanced:
def test_bedrock_retrieval_with_invalid_setting(self, app: Flask, mock_auth, current_user):
def test_bedrock_retrieval_with_invalid_setting(self, app: Flask):
api = BedrockRetrievalApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"retrieval_setting": {},

View File

@ -1,8 +1,9 @@
import inspect
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from flask import Flask
from controllers.console.workspace.account import (
AccountDeleteUpdateFeedbackApi,
@ -11,7 +12,7 @@ from controllers.console.workspace.account import (
ChangeEmailSendEmailApi,
CheckEmailUnique,
)
from models import Account, AccountStatus
from models import Account, AccountStatus, Tenant
from services.account_service import AccountService
from services.entities.auth_entities import (
ChangeEmailNewEmailToken,
@ -26,12 +27,16 @@ def app():
app = Flask(__name__)
app.config["TESTING"] = True
app.config["RESTX_MASK_HEADER"] = "X-Fields"
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
setattr(app, "login_manager", SimpleNamespace(load_user_from_request_context=lambda: None)) # noqa: B010
return app
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
def _build_account(email: str, account_id: str = "acc", tenant: Tenant | None = None) -> Account:
if tenant is None:
tenant_obj = Tenant(name="Tenant")
tenant_obj.id = "tenant-id"
else:
tenant_obj = tenant
account = Account(name=account_id, email=email)
account.email = email
account.id = account_id
@ -40,11 +45,6 @@ def _build_account(email: str, account_id: str = "acc", tenant: object | None =
return account
def _set_logged_in_user(account: Account):
g._login_user = account
g._current_tenant = account.current_tenant
def _build_change_email_token(
phase: str,
*,
@ -71,63 +71,44 @@ def _build_change_email_token(
class TestChangeEmailSend:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_old_email_phase_when_request_email_does_not_match_current_user(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_current_account,
mock_db,
app: Flask,
):
from controllers.console.auth.error import InvalidEmailError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("current@example.com", "acc1"), None)
current_user = _build_account("current@example.com", "acc1")
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "other@example.com", "language": "en-US", "phase": "old_email"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
method = inspect.unwrap(ChangeEmailSendEmailApi().post)
with pytest.raises(InvalidEmailError):
ChangeEmailSendEmailApi().post()
method(ChangeEmailSendEmailApi(), current_user)
mock_send_email.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_phase(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
account_id="acc1",
@ -141,8 +122,9 @@ class TestChangeEmailSend:
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
api = ChangeEmailSendEmailApi()
method = inspect.unwrap(api.post)
response = method(api, mock_account)
assert response == {"result": "success", "data": "token-abc"}
mock_send_email.assert_called_once_with(
@ -154,34 +136,23 @@ class TestChangeEmailSend:
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_OLD,
account_id="acc1",
@ -194,37 +165,28 @@ class TestChangeEmailSend:
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailSendEmailApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailSendEmailApi().post()
method(api, mock_account)
mock_send_email.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
account_id="other-account",
@ -237,41 +199,32 @@ class TestChangeEmailSend:
method="POST",
json={"email": "new@example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailSendEmailApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailSendEmailApi().post()
method(api, mock_account)
mock_send_email.assert_not_called()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_validate_with_normalized_email(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_OLD,
@ -286,8 +239,9 @@ class TestChangeEmailValidity:
method="POST",
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
api = ChangeEmailCheckApi()
method = inspect.unwrap(api.post)
response = method(api, mock_account)
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
mock_is_rate_limit.assert_called_once_with("user@example.com")
@ -303,34 +257,24 @@ class TestChangeEmailValidity:
mock_account,
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_upgrade_new_phase_token_to_new_verified(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
current_user = _build_account("old@example.com", "acc")
mock_is_rate_limit.return_value = False
mock_get_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_NEW,
@ -345,8 +289,9 @@ class TestChangeEmailValidity:
method="POST",
json={"email": "new@example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
api = ChangeEmailCheckApi()
method = inspect.unwrap(api.post)
response = method(api, current_user)
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"}
mock_generate_token.assert_called_once_with(
@ -356,37 +301,28 @@ class TestChangeEmailValidity:
email="new@example.com",
old_email="old@example.com",
),
mock_current_account.return_value[0],
current_user,
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_validity_when_token_is_already_verified(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
current_user = _build_account("old@example.com", "acc")
mock_is_rate_limit.return_value = False
mock_get_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
@ -400,41 +336,33 @@ class TestChangeEmailValidity:
method="POST",
json={"email": "old@example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailCheckApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailCheckApi().post()
method(api, current_user)
mock_revoke_token.assert_not_called()
mock_generate_token.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_validity_when_token_account_id_does_not_match_current_user(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
current_user = _build_account("old@example.com", "acc")
mock_is_rate_limit.return_value = False
mock_get_data.return_value = _build_change_email_token(
AccountService.CHANGE_EMAIL_PHASE_NEW,
@ -448,42 +376,33 @@ class TestChangeEmailValidity:
method="POST",
json={"email": "new@example.com", "code": "1234", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailCheckApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailCheckApi().post()
method(api, current_user)
mock_revoke_token.assert_not_called()
mock_generate_token.assert_not_called()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_before_update(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = _build_change_email_token(
@ -500,46 +419,36 @@ class TestChangeEmailReset:
method="POST",
json={"new_email": "New@Example.com", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
ChangeEmailResetApi().post()
api = ChangeEmailResetApi()
method = inspect.unwrap(api.post)
method(api, current_user)
mock_is_freeze.assert_called_once_with("new@example.com")
mock_check_unique.assert_called_once_with("new@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_reset_when_token_phase_is_not_new_verified(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = _build_change_email_token(
@ -554,44 +463,35 @@ class TestChangeEmailReset:
method="POST",
json={"new_email": "attacker@example.com", "token": "token-from-step1"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailResetApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
method(api, current_user)
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app: Flask,
):
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = _build_change_email_token(
@ -606,43 +506,34 @@ class TestChangeEmailReset:
method="POST",
json={"new_email": "attacker@example.com", "token": "token-verified"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailResetApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
method(api, current_user)
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_reset_when_token_account_id_does_not_match_current_user(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app: Flask,
):
from controllers.console.auth.error import InvalidTokenError
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = _build_change_email_token(
@ -657,9 +548,10 @@ class TestChangeEmailReset:
method="POST",
json={"new_email": "new@example.com", "token": "token-verified"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
api = ChangeEmailResetApi()
method = inspect.unwrap(api.post)
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
method(api, current_user)
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
@ -755,25 +647,25 @@ class TestAccountServiceGetChangeEmailData:
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask):
def test_should_normalize_feedback_email(self, mock_update, app: Flask):
with app.test_request_context(
"/account/delete/feedback",
method="POST",
json={"email": "User@Example.com", "feedback": "test"},
):
response = AccountDeleteUpdateFeedbackApi().post()
api = AccountDeleteUpdateFeedbackApi()
method = inspect.unwrap(api.post)
response = method(api)
assert response == {"result": "success"}
mock_update.assert_called_once_with("User@Example.com", "test")
class TestCheckEmailUnique:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask):
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask):
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
@ -782,7 +674,9 @@ class TestCheckEmailUnique:
method="POST",
json={"email": "Case@Test.com"},
):
response = CheckEmailUnique().post()
api = CheckEmailUnique()
method = inspect.unwrap(api.post)
response = method(api)
assert response == {"result": "success"}
mock_is_freeze.assert_called_once_with("case@test.com")

View File

@ -1,3 +1,4 @@
import inspect
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
@ -31,22 +32,29 @@ from controllers.console.workspace.error import (
CurrentPasswordIncorrectError,
InvalidAccountDeletionCodeError,
)
from models import Account
from models.account import AccountStatus
from models.enums import CreatorUserRole
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_account(account_id: str = "u1", *, status: AccountStatus = AccountStatus.ACTIVE) -> Account:
account = Account(name="John", email=f"{account_id}@test.com", status=status)
account.id = account_id
account.avatar = "avatar.png"
account.interface_language = "en-US"
account.interface_theme = "light"
account.timezone = "UTC"
account.last_login_ip = "127.0.0.1"
return account
class TestAccountInitApi:
def test_init_success(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
account = MagicMock(status="inactive")
account = make_account(status=AccountStatus.UNINITIALIZED)
payload = {
"interface_language": "en-US",
"timezone": "UTC",
@ -55,50 +63,35 @@ class TestAccountInitApi:
with (
app.test_request_context("/account/init", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
):
scalar_mock.return_value = MagicMock(status="unused")
resp = method(api)
resp = method(api, account)
assert resp["result"] == "success"
def test_init_already_initialized(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
account = MagicMock(status="active")
account = make_account()
with (
app.test_request_context("/account/init"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
):
with app.test_request_context("/account/init"):
with pytest.raises(AccountAlreadyInitedError):
method(api)
method(api, account)
class TestAccountProfileApi:
def test_get_profile_success(self, app: Flask):
api = AccountProfileApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
user = make_account()
with (
app.test_request_context("/account/profile"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
):
result = method(api)
with app.test_request_context("/account/profile"):
result = method(api, user)
assert result["id"] == "u1"
@ -116,24 +109,15 @@ class TestAccountUpdateApis:
)
def test_update_success(self, app: Flask, api_cls, payload):
api = api_cls()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
user = make_account()
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
):
result = method(api)
result = method(api, user)
assert result["id"] == "u1"
@ -143,10 +127,9 @@ class TestAccountAvatarApiGet:
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
user.id = "acc-owner"
user = make_account("acc-owner")
tenant_id = "tenant-1"
file_id = "550e8400-e29b-41d4-a716-446655440000"
@ -158,27 +141,22 @@ class TestAccountAvatarApiGet:
with (
app.test_request_context(f"/account/avatar?avatar={file_id}"),
patch(
"controllers.console.workspace.account.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file),
patch(
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
return_value="https://signed/example",
) as sign_mock,
):
result = method(api)
result = method(api, tenant_id, user)
assert result == {"avatar_url": "https://signed/example"}
sign_mock.assert_called_once_with(upload_file_id=file_id)
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
user.id = "acc-a"
user = make_account("acc-a")
tenant_id = "tenant-1"
file_id = "550e8400-e29b-41d4-a716-446655440001"
@ -190,10 +168,6 @@ class TestAccountAvatarApiGet:
with (
app.test_request_context(f"/account/avatar?avatar={file_id}"),
patch(
"controllers.console.workspace.account.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file),
patch(
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
@ -201,16 +175,15 @@ class TestAccountAvatarApiGet:
) as sign_mock,
):
with pytest.raises(NotFound):
method(api)
method(api, tenant_id, user)
sign_mock.assert_not_called()
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
user.id = "acc-owner"
user = make_account("acc-owner")
tenant_id = "tenant-1"
file_id = "550e8400-e29b-41d4-a716-446655440002"
@ -222,10 +195,6 @@ class TestAccountAvatarApiGet:
with (
app.test_request_context(f"/account/avatar?avatar={file_id}"),
patch(
"controllers.console.workspace.account.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file),
patch(
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
@ -233,31 +202,26 @@ class TestAccountAvatarApiGet:
) as sign_mock,
):
with pytest.raises(NotFound):
method(api)
method(api, tenant_id, user)
sign_mock.assert_not_called()
def test_get_avatar_https_pass_through_without_signing(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = MagicMock()
user.id = "acc-owner"
user = make_account("acc-owner")
tenant_id = "tenant-1"
external = "https://cdn.example/avatar.png"
with (
app.test_request_context(f"/account/avatar?avatar={external}"),
patch(
"controllers.console.workspace.account.current_account_with_tenant",
return_value=(user, tenant_id),
),
patch(
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
return_value="https://signed/should-not-use",
) as sign_mock,
):
result = method(api)
result = method(api, tenant_id, user)
assert result == {"avatar_url": external}
sign_mock.assert_not_called()
@ -266,7 +230,7 @@ class TestAccountAvatarApiGet:
class TestAccountPasswordApi:
def test_password_success(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"password": "old",
@ -274,63 +238,51 @@ class TestAccountPasswordApi:
"repeat_new_password": "new123",
}
user = MagicMock()
user.id = "u1"
user.name = "John"
user.email = "john@test.com"
user.avatar = "avatar.png"
user.interface_language = "en-US"
user.interface_theme = "light"
user.timezone = "UTC"
user.last_login_ip = "127.0.0.1"
user = make_account()
with (
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
):
result = method(api)
result = method(api, user)
assert result["id"] == "u1"
def test_password_wrong_current(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"password": "bad",
"new_password": "new123",
"repeat_new_password": "new123",
}
user = make_account()
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.update_account_password",
side_effect=ServicePwdError(),
),
):
with pytest.raises(CurrentPasswordIncorrectError):
method(api)
method(api, user)
class TestAccountIntegrateApi:
def test_get_integrates(self, app: Flask):
api = AccountIntegrateApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
account = MagicMock(id="acc1")
account = make_account("acc1")
with (
app.test_request_context("/"),
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
):
scalars_mock.return_value.all.return_value = []
result = method(api)
result = method(api, account)
assert "data" in result
assert len(result["data"]) == 2
@ -339,13 +291,11 @@ class TestAccountIntegrateApi:
class TestAccountDeleteApi:
def test_delete_verify_success(self, app: Flask):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = make_account()
with (
app.test_request_context("/"),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
return_value=("token", "1234"),
@ -355,43 +305,38 @@ class TestAccountDeleteApi:
return_value=None,
),
):
result = method(api)
result = method(api, user)
assert result["result"] == "success"
def test_delete_invalid_code(self, app: Flask):
api = AccountDeleteApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"token": "t", "code": "x"}
user = make_account()
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
return_value=False,
),
):
with pytest.raises(InvalidAccountDeletionCodeError):
method(api)
method(api, user)
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app: Flask):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"email": "a@test.com", "code": "x", "token": "t"}
user = make_account("acc-1")
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.workspace.account.current_account_with_tenant",
return_value=(MagicMock(id="acc-1"), "t1"),
),
patch.object(
type(console_ns),
"payload",
@ -412,13 +357,14 @@ class TestChangeEmailApis:
),
):
with pytest.raises(EmailCodeError):
method(api)
method(api, user)
def test_reset_email_already_used(self, app: Flask):
api = ChangeEmailResetApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"new_email": "x@test.com", "token": "t"}
user = make_account()
with (
app.test_request_context("/", json=payload),
@ -432,13 +378,13 @@ class TestChangeEmailApis:
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
):
with pytest.raises(EmailAlreadyInUseError):
method(api)
method(api, user)
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"email": "ok@test.com"}
@ -459,7 +405,7 @@ class TestCheckEmailUniqueApi:
def test_email_in_freeze(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"email": "x@test.com"}

View File

@ -1,4 +1,5 @@
from unittest.mock import MagicMock, patch
import inspect
from unittest.mock import patch
import pytest
from flask import Flask
@ -18,31 +19,10 @@ from controllers.console.workspace.endpoint import (
from core.plugin.impl.exc import PluginPermissionDeniedError
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
@pytest.fixture
def user_and_tenant():
return MagicMock(id="u1"), "t1"
@pytest.fixture
def patch_current_account(user_and_tenant):
with patch(
"controllers.console.workspace.endpoint.current_account_with_tenant",
return_value=user_and_tenant,
):
yield
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCollectionApi:
def test_create_success(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
@ -54,13 +34,13 @@ class TestEndpointCollectionApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is True
def test_create_permission_denied(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
@ -76,11 +56,11 @@ class TestEndpointCollectionApi:
),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
def test_create_validation_error(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"plugin_unique_identifier": "p1",
@ -92,14 +72,13 @@ class TestEndpointCollectionApi:
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointCreateApi:
def test_create_success(self, app: Flask):
api = DeprecatedEndpointCreateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"plugin_unique_identifier": "plugin-1",
@ -111,42 +90,40 @@ class TestDeprecatedEndpointCreateApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is True
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
):
result = method(api)
result = method(api, "t1", "u1")
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?page=0&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
@ -155,26 +132,25 @@ class TestEndpointListForSinglePluginApi:
return_value=[{"id": "e1"}],
),
):
result = method(api)
result = method(api, "t1", "u1")
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/?page=1&page_size=10"),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointItemApi:
def test_delete_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with (
app.test_request_context("/", method="DELETE"),
@ -183,26 +159,26 @@ class TestEndpointItemApi:
return_value=True,
) as mock_delete,
):
result = method(api, "e1")
result = method(api, "t1", "u1", "e1")
assert result["success"] is True
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
def test_delete_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
method = inspect.unwrap(api.delete)
with (
app.test_request_context("/", method="DELETE"),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api, "e1")
result = method(api, "t1", "u1", "e1")
assert result["success"] is False
def test_update_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
payload = {
"name": "new-name",
@ -216,7 +192,7 @@ class TestEndpointItemApi:
return_value=True,
) as mock_update,
):
result = method(api, "e1")
result = method(api, "t1", "u1", "e1")
assert result["success"] is True
mock_update.assert_called_once_with(
@ -229,7 +205,7 @@ class TestEndpointItemApi:
def test_update_validation_error(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
payload = {"settings": {}}
@ -237,11 +213,11 @@ class TestEndpointItemApi:
app.test_request_context("/", method="PATCH", json=payload),
):
with pytest.raises(ValueError):
method(api, "e1")
method(api, "t1", "u1", "e1")
def test_update_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
method = inspect.unwrap(api.patch)
payload = {
"name": "n",
@ -252,16 +228,15 @@ class TestEndpointItemApi:
app.test_request_context("/", method="PATCH", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api, "e1")
result = method(api, "t1", "u1", "e1")
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointDeleteApi:
def test_delete_success(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -269,23 +244,23 @@ class TestDeprecatedEndpointDeleteApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is True
def test_delete_invalid_payload(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
def test_delete_service_failure(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -293,16 +268,15 @@ class TestDeprecatedEndpointDeleteApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointUpdateApi:
def test_update_success(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"endpoint_id": "e1",
@ -314,13 +288,13 @@ class TestDeprecatedEndpointUpdateApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is True
def test_update_validation_error(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"endpoint_id": "e1", "settings": {}}
@ -328,11 +302,11 @@ class TestDeprecatedEndpointUpdateApi:
app.test_request_context("/", json=payload),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
def test_update_service_failure(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {
"endpoint_id": "e1",
@ -344,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is False
@ -379,11 +353,10 @@ class TestEndpointRouteMetadata:
assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",)
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -391,23 +364,23 @@ class TestEndpointEnableApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is True
def test_enable_invalid_payload(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")
def test_enable_service_failure(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -415,16 +388,15 @@ class TestEndpointEnableApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is False
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"endpoint_id": "e1"}
@ -432,16 +404,16 @@ class TestEndpointDisableApi:
app.test_request_context("/", json=payload),
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
):
result = method(api)
result = method(api, "t1", "u1")
assert result["success"] is True
def test_disable_invalid_payload(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
with (
app.test_request_context("/", json={}),
):
with pytest.raises(ValueError):
method(api)
method(api, "t1", "u1")

View File

@ -1,3 +1,4 @@
import inspect
from io import BytesIO
from unittest.mock import MagicMock, patch
@ -28,38 +29,47 @@ from controllers.console.workspace.workspace import (
)
from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.account import TenantStatus
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
def unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def make_account(account_id: str = "u1") -> Account:
account = Account(name="Test User", email=f"{account_id}@example.com")
account.id = account_id
return account
def make_tenant(
tenant_id: str = "t1",
*,
name: str | None = None,
status: TenantStatus = TenantStatus.NORMAL,
custom_config: TenantCustomConfigDict | None = None,
) -> Tenant:
tenant = Tenant(name=name or f"Tenant {tenant_id}", status=status)
tenant.id = tenant_id
tenant.created_at = naive_utc_now()
if custom_config is not None:
tenant.custom_config_dict = custom_config
return tenant
def make_account_with_tenant(tenant: Tenant) -> Account:
account = make_account()
account._current_tenant = tenant
return account
class TestTenantListApi:
def test_get_success_saas_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=naive_utc_now(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=naive_utc_now(),
)
tenant1 = make_tenant("t1", name="Tenant 1")
tenant2 = make_tenant("t2", name="Tenant 2")
user = make_account()
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
@ -76,7 +86,7 @@ class TestTenantListApi:
) as get_plan_bulk_mock,
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
result, status = method(api, "t1", user)
assert status == 200
assert len(result["workspaces"]) == 2
@ -93,30 +103,18 @@ class TestTenantListApi:
(SaaS contract treats enabled as on; display follows subscription.plan).
"""
api = TenantListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=naive_utc_now(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=naive_utc_now(),
)
tenant1 = make_tenant("t1", name="Tenant 1")
tenant2 = make_tenant("t2", name="Tenant 2")
features_t2 = MagicMock()
features_t2.billing.enabled = False
features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL
user = make_account()
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
@ -133,7 +131,7 @@ class TestTenantListApi:
return_value=features_t2,
) as get_features_mock,
):
result, status = method(api)
result, status = method(api, "t1", user)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
@ -148,30 +146,18 @@ class TestTenantListApi:
so we simulate the real failure mode by returning empty dict for non-empty input.
"""
api = TenantListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=naive_utc_now(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=naive_utc_now(),
)
tenant1 = make_tenant("t1", name="Tenant 1")
tenant2 = make_tenant("t2", name="Tenant 2")
features = MagicMock()
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.TEAM
user = make_account()
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
@ -189,7 +175,7 @@ class TestTenantListApi:
) as get_features_mock,
patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock,
):
result, status = method(api)
result, status = method(api, "t2", user)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
@ -200,25 +186,17 @@ class TestTenantListApi:
def test_get_billing_disabled_community_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant = MagicMock(
id="t1",
name="Tenant",
status="active",
created_at=naive_utc_now(),
)
tenant = make_tenant("t1", name="Tenant")
features = MagicMock()
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.SANDBOX
user = make_account()
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
@ -231,7 +209,7 @@ class TestTenantListApi:
return_value=features,
) as get_features_mock,
):
result, status = method(api)
result, status = method(api, "t1", user)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
@ -239,26 +217,14 @@ class TestTenantListApi:
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=naive_utc_now(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=naive_utc_now(),
)
tenant1 = make_tenant("t1", name="Tenant 1")
tenant2 = make_tenant("t2", name="Tenant 2")
user = make_account()
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
@ -268,7 +234,7 @@ class TestTenantListApi:
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
result, status = method(api, "t2", user)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
@ -279,13 +245,11 @@ class TestTenantListApi:
def test_get_enterprise_only_with_empty_tenants(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
user = make_account()
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[],
@ -295,7 +259,7 @@ class TestTenantListApi:
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
result, status = method(api, None, user)
assert status == 200
assert result["workspaces"] == []
@ -305,9 +269,9 @@ class TestTenantListApi:
class TestWorkspaceListApi:
def test_get_success(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now())
tenant = make_tenant("t1", name="T")
paginate_result = MagicMock(items=[tenant], has_next=False, total=1)
with (
@ -322,9 +286,9 @@ class TestWorkspaceListApi:
def test_get_has_next_true(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now())
tenant = make_tenant("t1", name="T")
paginate_result = MagicMock(items=[tenant], has_next=True, total=10)
with (
@ -340,80 +304,71 @@ class TestWorkspaceListApi:
class TestTenantApi:
def test_post_active_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
tenant = make_tenant()
user = make_account_with_tenant(tenant)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result, status = method(api)
result, status = method(api, user)
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
archived = MagicMock(status=TenantStatus.ARCHIVE)
new_tenant = MagicMock(status="active")
user = MagicMock(current_tenant=archived)
archived = make_tenant(status=TenantStatus.ARCHIVE)
new_tenant = make_tenant("new")
user = make_account_with_tenant(archived)
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
),
):
result, status = method(api)
result, status = method(api, user)
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
user = make_account_with_tenant(make_tenant(status=TenantStatus.ARCHIVE))
with (
app.test_request_context("/workspaces/current"),
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
):
with pytest.raises(Unauthorized):
method(api)
method(api, user)
def test_post_info_path(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
tenant = MagicMock(status="active")
user = MagicMock(current_tenant=tenant)
tenant = make_tenant()
user = make_account_with_tenant(tenant)
with (
app.test_request_context("/info"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(user, "t1"),
),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
return_value={"id": "t1"},
),
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
):
result, status = method(api)
result, status = method(api, user)
warn_mock.assert_called_once()
assert status == 200
@ -456,16 +411,14 @@ class TestTenantInfoResponse:
class TestSwitchWorkspaceApi:
def test_switch_success(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"tenant_id": "t2"}
tenant = MagicMock(id="t2")
tenant = make_tenant("t2")
user = make_account()
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
patch(
@ -473,85 +426,73 @@ class TestSwitchWorkspaceApi:
),
):
get_mock.return_value = tenant
result = method(api)
result = method(api, user)
assert result["result"] == "success"
def test_switch_not_linked(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"tenant_id": "bad"}
user = make_account()
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
):
with pytest.raises(AccountNotLinkTenantError):
method(api)
method(api, user)
def test_switch_tenant_not_found(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"tenant_id": "missing"}
user = make_account()
with (
app.test_request_context("/workspaces/switch", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
):
get_mock.return_value = None
with pytest.raises(ValueError):
method(api)
method(api, user)
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
tenant = MagicMock(custom_config_dict={})
tenant = make_tenant(custom_config={})
payload = {"remove_webapp_brand": True}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
),
):
result = method(api)
result = method(api, "t1")
assert result["result"] == "success"
def test_logo_fallback(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
tenant = make_tenant(custom_config={"replace_webapp_logo": "old-logo"})
payload = {"remove_webapp_brand": False}
with (
app.test_request_context("/workspaces/custom-config", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch(
"controllers.console.workspace.workspace.db.get_or_404",
return_value=tenant,
@ -562,7 +503,7 @@ class TestCustomConfigWorkspaceApi:
return_value={"id": "t1"},
),
):
result = method(api)
result = method(api, "t1")
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
assert result["result"] == "success"
@ -571,54 +512,41 @@ class TestCustomConfigWorkspaceApi:
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
user = make_account()
with (
app.test_request_context("/upload", data={}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with app.test_request_context("/upload", data={}):
with pytest.raises(NoFileUploadedError):
method(api)
method(api, user)
def test_too_many_files(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
data = {
"file": MagicMock(),
"extra": MagicMock(),
}
user = make_account()
with (
app.test_request_context("/upload", data=data),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
):
with app.test_request_context("/upload", data=data):
with pytest.raises(TooManyFilesError):
method(api)
method(api, user)
def test_invalid_extension(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
file = MagicMock(filename="test.txt")
user = make_account()
with (
app.test_request_context("/upload", data={"file": file}),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
):
with app.test_request_context("/upload", data={"file": file}):
with pytest.raises(UnsupportedFileTypeError):
method(api)
method(api, user)
def test_upload_success(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
@ -627,6 +555,7 @@ class TestWebappLogoWorkspaceApi:
)
upload = MagicMock(id="file1")
user = make_account()
with (
app.test_request_context(
@ -634,53 +563,46 @@ class TestWebappLogoWorkspaceApi:
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
mock_db.engine = MagicMock()
fs.return_value.upload_file.return_value = upload
result, status = method(api)
result, status = method(api, user)
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"data"),
filename="",
content_type="image/png",
)
user = make_account()
with (
app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
with app.test_request_context(
"/upload",
data={"file": file},
content_type="multipart/form-data",
):
with pytest.raises(FilenameNotExistsError):
method(api)
method(api, user)
def test_file_too_large(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
user = make_account()
with (
app.test_request_context(
@ -688,10 +610,6 @@ class TestWebappLogoWorkspaceApi:
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
@ -699,17 +617,18 @@ class TestWebappLogoWorkspaceApi:
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
with pytest.raises(FileTooLargeError):
method(api)
method(api, user)
def test_service_unsupported_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
file = FileStorage(
stream=BytesIO(b"x"),
filename="logo.png",
content_type="image/png",
)
user = make_account()
with (
app.test_request_context(
@ -717,10 +636,6 @@ class TestWebappLogoWorkspaceApi:
data={"file": file},
content_type="multipart/form-data",
),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), "t1"),
),
patch("controllers.console.workspace.workspace.FileService") as fs,
patch("controllers.console.workspace.workspace.db") as mock_db,
):
@ -728,23 +643,20 @@ class TestWebappLogoWorkspaceApi:
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
with pytest.raises(UnsupportedFileTypeError):
method(api)
method(api, user)
class TestWorkspaceInfoApi:
def test_post_success(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
tenant = MagicMock()
tenant = make_tenant()
payload = {"name": "New Name"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
patch("controllers.console.workspace.workspace.db.session.commit"),
patch(
@ -752,31 +664,27 @@ class TestWorkspaceInfoApi:
return_value={"name": "New Name"},
),
):
result = method(api)
result = method(api, "t1")
assert result["result"] == "success"
def test_no_current_tenant(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
method = inspect.unwrap(api.post)
payload = {"name": "X"}
with (
app.test_request_context("/workspaces/info", json=payload),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with pytest.raises(ValueError):
method(api)
method(api, None)
class TestWorkspacePermissionApi:
def test_get_success(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
permission = MagicMock(
workspace_id="t1",
@ -786,29 +694,20 @@ class TestWorkspacePermissionApi:
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
return_value=permission,
),
):
result, status = method(api)
result, status = method(api, "t1")
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)
method = inspect.unwrap(api.get)
with (
app.test_request_context("/permission"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with app.test_request_context("/permission"):
with pytest.raises(ValueError):
method(api)
method(api, None)

View File

@ -82,11 +82,13 @@ This test suite follows a comprehensive testing strategy that covers:
================================================================================
"""
from collections.abc import Iterator
from types import SimpleNamespace
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from flask import Flask
from flask import Flask, g
from flask.testing import FlaskClient
from flask_restx import Api
@ -95,6 +97,7 @@ from controllers.console.datasets.external import (
ExternalApiTemplateListApi,
)
from controllers.console.datasets.hit_testing import HitTestingApi
from models.account import Account, AccountStatus, TenantAccountRole
from models.dataset import Dataset, DatasetPermissionEnum
# ============================================================================
@ -817,15 +820,32 @@ class TestExternalDatasetApi:
)
@pytest.fixture
def mock_current_user(self):
"""Mock current user and tenant context."""
with patch("controllers.console.datasets.external.current_account_with_tenant") as mock_get_user:
mock_user = ControllerApiTestDataFactory.create_user_mock(is_dataset_editor=True)
def mock_current_account_context(self, app: Flask) -> Iterator[Mock]:
"""Provide the wrapper auth context required by HTTP-client controller tests."""
mock_user = Account(name="Test User", email="user-123@example.com")
mock_user.id = "user-123"
mock_user.status = AccountStatus.ACTIVE
mock_user.role = TenantAccountRole.EDITOR
def load_user_from_request_context() -> None:
g._login_user = mock_user
setattr( # noqa: B010
app,
"login_manager",
SimpleNamespace(load_user_from_request_context=load_user_from_request_context),
)
with (
patch("controllers.console.wraps.current_account_with_tenant") as mock_get_user,
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
patch("libs.login.check_csrf_token", return_value=None),
):
mock_tenant_id = "tenant-123"
mock_get_user.return_value = (mock_user, mock_tenant_id)
yield mock_get_user
def test_get_external_knowledge_apis_success(self, client_list, mock_current_user):
def test_get_external_knowledge_apis_success(self, client_list: FlaskClient, mock_current_account_context: Mock):
"""
Test successful retrieval of external knowledge API list.
@ -839,7 +859,11 @@ class TestExternalDatasetApi:
- Status code is 200
"""
# Arrange
apis = [{"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"} for i in range(3)]
apis = []
for i in range(3):
api_item = Mock()
api_item.to_dict.return_value = {"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"}
apis.append(api_item)
with patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis"
@ -847,6 +871,7 @@ class TestExternalDatasetApi:
mock_get_apis.return_value = (apis, 3)
# Act
# TODO: this should be made integrated tests...
response = client_list.get("/datasets/external-knowledge-api?page=1&limit=20")
# Assert