mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:13:59 +08:00
refactor(api): migrate tenant/user via DI for several endpoints (#37026)
This commit is contained in:
parent
5b5a06136a
commit
b67c3a5f76
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Concatenate, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -214,7 +214,9 @@ workflow_draft_variable_list_model = console_ns.model(
|
||||
)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -231,8 +233,8 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -24,14 +24,14 @@ from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DataSourceOauthBinding, Document
|
||||
from libs.login import login_required
|
||||
from models import Account, DataSourceOauthBinding, Document
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
from .. import console_ns
|
||||
from ..wraps import account_initialization_required, setup_required
|
||||
from ..wraps import account_initialization_required, setup_required, with_current_tenant_id, with_current_user
|
||||
|
||||
|
||||
class NotionEstimatePayload(BaseModel):
|
||||
@ -130,9 +130,8 @@ class DataSourceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[DataSourceIntegrateListResponse.__name__])
|
||||
def get(self) -> tuple[dict[str, Any], int]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.scalars(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@ -180,8 +179,10 @@ class DataSourceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, binding_id: UUID, action: Literal["enable", "disable"]) -> tuple[dict[str, str], int]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, binding_id: UUID, action: Literal["enable", "disable"]
|
||||
) -> tuple[dict[str, str], int]:
|
||||
binding_id_str = str(binding_id)
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
data_source_binding = session.execute(
|
||||
@ -220,9 +221,9 @@ class DataSourceNotionListApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionListQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[NotionIntegrateInfoListResponse.__name__])
|
||||
def get(self) -> tuple[dict[str, Any], int]:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, Any], int]:
|
||||
query = DataSourceNotionListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
credential = datasource_provider_service.get_datasource_credentials(
|
||||
@ -311,9 +312,8 @@ class DataSourceNotionPreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.doc(params=query_params_from_model(DataSourceNotionPreviewQuery))
|
||||
@console_ns.response(200, "Success", console_ns.models[TextContentResponse.__name__])
|
||||
def get(self, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, page_id: UUID, page_type: str) -> tuple[dict[str, str], int]:
|
||||
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -347,9 +347,8 @@ class DataSourceNotionIndexingEstimateApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[NotionEstimatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[IndexingEstimate.__name__])
|
||||
def post(self) -> tuple[dict[str, Any], int]:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str) -> tuple[dict[str, Any], int]:
|
||||
payload = NotionEstimatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump()
|
||||
# validate args
|
||||
|
||||
@ -44,8 +44,8 @@ from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from libs.login import login_required
|
||||
from models import Account, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.enums import IndexingStatus, SegmentStatus
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
@ -71,6 +71,8 @@ from ..wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -169,8 +171,9 @@ register_response_schema_models(
|
||||
|
||||
|
||||
class DocumentResource(Resource):
|
||||
def get_document(self, dataset_id: str, document_id: str) -> Document:
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
def get_document(
|
||||
self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str
|
||||
) -> Document:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -190,8 +193,7 @@ class DocumentResource(Resource):
|
||||
|
||||
return document
|
||||
|
||||
def get_batch_documents(self, dataset_id: str, batch: str) -> Sequence[Document]:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]:
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
@ -218,8 +220,8 @@ class GetProcessRuleApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
req_data = request.args
|
||||
|
||||
document_id = req_data.get("document_id")
|
||||
@ -279,8 +281,9 @@ class DatasetDocumentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
raw_args = request.args.to_dict()
|
||||
param = DocumentDatasetListParam.model_validate(raw_args)
|
||||
@ -405,8 +408,8 @@ class DatasetDocumentListApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[KnowledgeConfig.__name__])
|
||||
@console_ns.response(200, "Documents created successfully", console_ns.models[DatasetAndDocumentResponse.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -480,9 +483,10 @@ class DatasetInitApi(Resource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
@ -539,11 +543,12 @@ class DocumentIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}:
|
||||
raise DocumentAlreadyFinishedError()
|
||||
@ -604,10 +609,11 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
if not documents:
|
||||
return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
|
||||
data_process_rule = documents[0].dataset_process_rule
|
||||
@ -704,9 +710,10 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, batch: str):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, batch: str):
|
||||
dataset_id_str = str(dataset_id)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch)
|
||||
documents = self.get_batch_documents(dataset_id_str, batch, current_user)
|
||||
documents_status = []
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -759,16 +766,18 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
completed_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -777,7 +786,7 @@ class DocumentIndexingStatusApi(DocumentResource):
|
||||
total_segments = (
|
||||
db.session.scalar(
|
||||
select(func.count(DocumentSegment.id)).where(
|
||||
DocumentSegment.document_id == str(document_id_str),
|
||||
DocumentSegment.document_id == document_id_str,
|
||||
DocumentSegment.status != SegmentStatus.RE_SEGMENT,
|
||||
)
|
||||
)
|
||||
@ -820,10 +829,12 @@ class DocumentApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
metadata = request.args.get("metadata", "all")
|
||||
if metadata not in self.METADATA_CHOICES:
|
||||
@ -909,7 +920,9 @@ class DocumentApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(204, "Document deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -918,7 +931,7 @@ class DocumentApi(DocumentResource):
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
try:
|
||||
DocumentService.delete_document(document)
|
||||
@ -939,9 +952,11 @@ class DocumentDownloadApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id)
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@ -956,12 +971,13 @@ class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
@ -1003,11 +1019,19 @@ class DocumentProcessingApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["pause", "resume"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["pause", "resume"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
|
||||
if not current_user.is_dataset_editor:
|
||||
@ -1051,11 +1075,12 @@ class DocumentMetadataApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def put(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def put(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
document = self.get_document(dataset_id_str, document_id_str)
|
||||
document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id)
|
||||
|
||||
req_data = DocumentMetadataUpdatePayload.model_validate(request.get_json() or {})
|
||||
|
||||
@ -1100,8 +1125,10 @@ class DocumentStatusApi(DocumentResource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(
|
||||
self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"]
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
@ -1216,8 +1243,6 @@ class DocumentRetryApi(DocumentResource):
|
||||
raise NotFound("Dataset not found.")
|
||||
for document_id in payload.document_ids:
|
||||
try:
|
||||
document_id = str(document_id)
|
||||
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
|
||||
# 404 if document not found
|
||||
@ -1248,9 +1273,9 @@ class DocumentRenameApi(DocumentResource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Document renamed successfully", console_ns.models[DocumentResponse.__name__])
|
||||
@console_ns.expect(console_ns.models[DocumentRenamePayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -1273,9 +1298,9 @@ class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID):
|
||||
"""sync website document."""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -1351,7 +1376,8 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id: UUID):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
@ -1359,7 +1385,6 @@ class DocumentGenerateSummaryApi(Resource):
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
@ -1444,7 +1469,8 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
@ -1457,7 +1483,6 @@ class DocumentSummaryStatusApi(DocumentResource):
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
|
||||
|
||||
@ -33,6 +33,8 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
@ -51,7 +53,8 @@ from fields.segment_fields import (
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from libs.helper import dump_response, escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
@ -164,9 +167,9 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
document_id_str = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -274,9 +277,8 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT)
|
||||
@console_ns.doc(params=query_params_from_model(SegmentIdListQuery))
|
||||
@console_ns.response(204, "Segments deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -312,9 +314,16 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
@cloud_edition_billing_resource_check("vector_space")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, action: Literal["enable", "disable"]):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
action: Literal["enable", "disable"],
|
||||
):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if not dataset:
|
||||
@ -373,9 +382,9 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentCreatePayload.__name__])
|
||||
@console_ns.response(200, "Segment created successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -431,9 +440,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[SegmentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Segment updated successfully", console_ns.models[SegmentDetailResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -500,9 +511,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_SEGMENT)
|
||||
@console_ns.response(204, "Segment deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -548,9 +561,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
@cloud_edition_billing_knowledge_limit_check("add_segment")
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[BatchImportPayload.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -619,9 +632,11 @@ class ChildChunkAddApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[ChildChunkCreatePayload.__name__])
|
||||
@console_ns.response(200, "Child chunk created successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
def post(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -677,9 +692,8 @@ class ChildChunkAddApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -731,9 +745,11 @@ class ChildChunkAddApi(Resource):
|
||||
console_ns.models[ChildChunkBatchUpdateResponse.__name__],
|
||||
)
|
||||
@console_ns.expect(console_ns.models[ChildChunkBatchUpdatePayload.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID, segment_id: UUID
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -781,9 +797,17 @@ class ChildChunkUpdateApi(Resource):
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.response(204, "Child chunk deleted successfully")
|
||||
def delete(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
@ -840,9 +864,17 @@ class ChildChunkUpdateApi(Resource):
|
||||
@console_ns.doc(params=SegmentDocParams.DATASET_DOCUMENT_CHILD_CHUNK)
|
||||
@console_ns.expect(console_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Child chunk updated successfully", console_ns.models[ChildChunkDetailResponse.__name__])
|
||||
def patch(self, dataset_id: UUID, document_id: UUID, segment_id: UUID, child_chunk_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(
|
||||
self,
|
||||
current_tenant_id: str,
|
||||
current_user: Account,
|
||||
dataset_id: UUID,
|
||||
document_id: UUID,
|
||||
segment_id: UUID,
|
||||
child_chunk_id: UUID,
|
||||
):
|
||||
# check dataset
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.console.wraps import (
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.dataset_fields import (
|
||||
dataset_detail_fields,
|
||||
@ -29,7 +30,8 @@ from fields.dataset_fields import (
|
||||
vector_setting_fields,
|
||||
weighted_score_fields,
|
||||
)
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
@ -152,8 +154,9 @@ class ExternalApiTemplateListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def post(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
ExternalDatasetService.validate_api_list(payload.settings)
|
||||
@ -182,8 +185,8 @@ class ExternalApiTemplateApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
external_knowledge_api = ExternalDatasetService.get_external_knowledge_api(
|
||||
external_knowledge_api_id_str, current_tenant_id
|
||||
@ -197,8 +200,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.expect(console_ns.models[ExternalKnowledgeApiPayload.__name__])
|
||||
def patch(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def patch(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
payload = ExternalKnowledgeApiPayload.model_validate(console_ns.payload or {})
|
||||
@ -217,8 +221,9 @@ class ExternalApiTemplateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(204, "External knowledge API deleted successfully")
|
||||
def delete(self, external_knowledge_api_id: UUID):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, current_user: Account, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_operator):
|
||||
@ -237,8 +242,8 @@ class ExternalApiUseCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, external_knowledge_api_id: UUID):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, external_knowledge_api_id: UUID):
|
||||
external_knowledge_api_id_str = str(external_knowledge_api_id)
|
||||
|
||||
external_knowledge_api_is_using, count = ExternalDatasetService.external_knowledge_api_use_check(
|
||||
@ -259,9 +264,10 @@ class ExternalDatasetCreateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, current_user: Account):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = ExternalDatasetCreatePayload.model_validate(console_ns.payload or {})
|
||||
args = payload.model_dump(exclude_none=True)
|
||||
|
||||
@ -293,8 +299,8 @@ class ExternalKnowledgeHitTestingApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, dataset_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, dataset_id: UUID):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
|
||||
@ -9,11 +9,18 @@ from configs import dify_config
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from core.plugin.impl.oauth import OAuthHandler
|
||||
from graphon.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
@ -66,11 +73,10 @@ class DatasourcePluginOAuthAuthorizationUrl(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, provider_id: str):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account, provider_id: str):
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
credential_id = request.args.get("credential_id")
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -174,9 +180,8 @@ class DatasourceAuth(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCredentialPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -195,10 +200,11 @@ class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider_id: str):
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, user: Account, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
user, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
datasources = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=current_tenant_id,
|
||||
@ -217,9 +223,8 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
plugin_id = datasource_provider_id.plugin_id
|
||||
provider_name = datasource_provider_id.provider_name
|
||||
@ -242,9 +247,8 @@ class DatasourceAuthUpdateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
payload = DatasourceCredentialUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -265,9 +269,8 @@ class DatasourceAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_all_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -278,9 +281,8 @@ class DatasourceHardCodeAuthListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_hard_code_datasource_credentials(tenant_id=current_tenant_id)
|
||||
return {"result": jsonable_encoder(datasources)}, 200
|
||||
@ -293,9 +295,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceCustomClientPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -311,9 +312,8 @@ class DatasourceAuthOauthCustomClient(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def delete(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, provider_id: str):
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_oauth_custom_client_params(
|
||||
@ -331,9 +331,8 @@ class DatasourceAuthDefaultApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceDefaultPayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
@ -353,9 +352,8 @@ class DatasourceUpdateProviderNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def post(self, provider_id: str):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str, provider_id: str):
|
||||
payload = DatasourceUpdateNamePayload.model_validate(console_ns.payload or {})
|
||||
datasource_provider_id = DatasourceProviderID(provider_id)
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, NoReturn
|
||||
from typing import Any, Concatenate, NoReturn
|
||||
from uuid import UUID
|
||||
|
||||
from flask import Response, request
|
||||
@ -57,7 +57,9 @@ class WorkflowDraftVariablePatchPayload(BaseModel):
|
||||
register_schema_models(console_ns, WorkflowDraftVariablePatchPayload)
|
||||
|
||||
|
||||
def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
def _api_prerequisite[T, **P, R](
|
||||
f: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]:
|
||||
"""Common prerequisites for all draft workflow variable APIs.
|
||||
|
||||
It ensures the following conditions are satisfied:
|
||||
@ -72,10 +74,10 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_rag_pipeline
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
if not isinstance(current_user, Account) or not current_user.has_edit_permission:
|
||||
raise Forbidden()
|
||||
return f(*args, **kwargs)
|
||||
return f(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@ -42,6 +42,8 @@ from controllers.console.wraps import (
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
@ -49,8 +51,8 @@ from fields.member_fields import Account as AccountResponse
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import EmailStr, dump_response, extract_remote_ip, timezone, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import AccountIntegrate, InvitationCode
|
||||
from libs.login import login_required
|
||||
from models import Account, AccountIntegrate, InvitationCode
|
||||
from models.account import AccountStatus, InvitationCodeStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import UploadFile
|
||||
@ -258,9 +260,8 @@ class AccountInitApi(Resource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
if account.status == "active":
|
||||
raise AccountAlreadyInitedError()
|
||||
|
||||
@ -306,8 +307,8 @@ class AccountProfileApi(Resource):
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account):
|
||||
return _serialize_account(current_user)
|
||||
|
||||
|
||||
@ -318,8 +319,8 @@ class AccountNameApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountNamePayload.model_validate(payload)
|
||||
updated_account = AccountService.update_account(current_user, name=args.name)
|
||||
@ -336,8 +337,9 @@ class AccountAvatarApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True))
|
||||
avatar = args.avatar
|
||||
|
||||
@ -362,8 +364,8 @@ class AccountAvatarApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountAvatarPayload.model_validate(payload)
|
||||
|
||||
@ -379,8 +381,8 @@ class AccountInterfaceLanguageApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceLanguagePayload.model_validate(payload)
|
||||
|
||||
@ -396,8 +398,8 @@ class AccountInterfaceThemeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountInterfaceThemePayload.model_validate(payload)
|
||||
|
||||
@ -413,8 +415,8 @@ class AccountTimezoneApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountTimezonePayload.model_validate(payload)
|
||||
|
||||
@ -430,8 +432,8 @@ class AccountPasswordApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountPasswordPayload.model_validate(payload)
|
||||
|
||||
@ -449,9 +451,8 @@ class AccountIntegrateApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountIntegrateListResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
account_integrates = db.session.scalars(
|
||||
select(AccountIntegrate).where(AccountIntegrate.account_id == account.id)
|
||||
).all()
|
||||
@ -495,9 +496,8 @@ class AccountDeleteVerifyApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultDataResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
token, code = AccountService.generate_account_deletion_verification_code(account)
|
||||
AccountService.send_account_deletion_verification_email(account, code)
|
||||
|
||||
@ -511,9 +511,8 @@ class AccountDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = AccountDeletePayload.model_validate(payload)
|
||||
|
||||
@ -547,9 +546,8 @@ class EducationVerifyApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationVerifyResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
return EducationVerifyResponse.model_validate(
|
||||
BillingService.EducationIdentity.verify(account.id, account.email) or {}
|
||||
).model_dump(mode="json")
|
||||
@ -563,9 +561,8 @@ class EducationApi(Resource):
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
def post(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def post(self, account: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = EducationActivatePayload.model_validate(payload)
|
||||
|
||||
@ -577,9 +574,8 @@ class EducationApi(Resource):
|
||||
@only_edition_cloud
|
||||
@cloud_edition_billing_enabled
|
||||
@console_ns.response(200, "Success", console_ns.models[EducationStatusResponse.__name__])
|
||||
def get(self):
|
||||
account, _ = current_account_with_tenant()
|
||||
|
||||
@with_current_user
|
||||
def get(self, account: Account):
|
||||
res = BillingService.EducationIdentity.status(account.id) or {}
|
||||
# convert expire_at to UTC timestamp from isoformat
|
||||
if res and "expire_at" in res:
|
||||
@ -613,8 +609,8 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailSendPayload.model_validate(payload)
|
||||
|
||||
@ -673,8 +669,8 @@ class ChangeEmailCheckApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailValidityPayload.model_validate(payload)
|
||||
|
||||
@ -720,7 +716,8 @@ class ChangeEmailResetApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[AccountResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = ChangeEmailResetPayload.model_validate(payload)
|
||||
normalized_new_email = args.new_email.lower()
|
||||
@ -731,7 +728,6 @@ class ChangeEmailResetApi(Resource):
|
||||
if not AccountService.check_email_unique(normalized_new_email):
|
||||
raise EmailAlreadyInUseError()
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
@ -14,10 +14,16 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user_id,
|
||||
)
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from services.plugin.endpoint_service import EndpointService
|
||||
|
||||
|
||||
@ -96,17 +102,15 @@ register_schema_models(
|
||||
)
|
||||
|
||||
|
||||
def _create_endpoint() -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the current workspace."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
def _create_endpoint(tenant_id: str, user_id: str) -> dict[str, bool]:
|
||||
"""Create a plugin endpoint for the injected workspace and user."""
|
||||
args = EndpointCreatePayload.model_validate(console_ns.payload)
|
||||
|
||||
try:
|
||||
return {
|
||||
"success": EndpointService.create_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_unique_identifier=args.plugin_unique_identifier,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -116,16 +120,14 @@ def _create_endpoint() -> dict[str, bool]:
|
||||
raise ValueError(e.description) from e
|
||||
|
||||
|
||||
def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _update_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Update a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
args = EndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.update_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
name=args.name,
|
||||
settings=args.settings,
|
||||
@ -133,14 +135,12 @@ def _update_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
}
|
||||
|
||||
|
||||
def _delete_endpoint(endpoint_id: str) -> dict[str, bool]:
|
||||
def _delete_endpoint(tenant_id: str, user_id: str, endpoint_id: str) -> dict[str, bool]:
|
||||
"""Delete a plugin endpoint identified by the canonical path parameter."""
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
return {
|
||||
"success": EndpointService.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
endpoint_id=endpoint_id,
|
||||
)
|
||||
}
|
||||
@ -163,8 +163,10 @@ class EndpointCollectionApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/create")
|
||||
@ -189,8 +191,10 @@ class DeprecatedEndpointCreateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_endpoint()
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
return _create_endpoint(tenant_id=tenant_id, user_id=user_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/list")
|
||||
@ -206,9 +210,9 @@ class EndpointListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -218,7 +222,7 @@ class EndpointListApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
@ -239,9 +243,9 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def get(self, tenant_id: str, user_id: str):
|
||||
args = EndpointListForPluginQuery.model_validate(request.args.to_dict(flat=True))
|
||||
|
||||
page = args.page
|
||||
@ -252,7 +256,7 @@ class EndpointListForSinglePluginApi(Resource):
|
||||
{
|
||||
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
plugin_id=plugin_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
@ -278,8 +282,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def delete(self, id: str):
|
||||
return _delete_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def delete(self, tenant_id: str, user_id: str, id: str):
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
@console_ns.doc("update_endpoint")
|
||||
@console_ns.doc(description="Update a plugin endpoint")
|
||||
@ -295,8 +301,10 @@ class EndpointItemApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def patch(self, id: str):
|
||||
return _update_endpoint(endpoint_id=id)
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def patch(self, tenant_id: str, user_id: str, id: str):
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/delete")
|
||||
@ -322,9 +330,11 @@ class DeprecatedEndpointDeleteApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
return _delete_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _delete_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/update")
|
||||
@ -350,9 +360,11 @@ class DeprecatedEndpointUpdateApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = LegacyEndpointUpdatePayload.model_validate(console_ns.payload)
|
||||
return _update_endpoint(endpoint_id=args.endpoint_id)
|
||||
return _update_endpoint(tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id)
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/endpoints/enable")
|
||||
@ -370,14 +382,14 @@ class EndpointEnableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.enable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -397,13 +409,13 @@ class EndpointDisableApi(Resource):
|
||||
@login_required
|
||||
@is_admin_or_owner_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
|
||||
@with_current_user_id
|
||||
@with_current_tenant_id
|
||||
def post(self, tenant_id: str, user_id: str):
|
||||
args = EndpointIdPayload.model_validate(console_ns.payload)
|
||||
|
||||
return {
|
||||
"success": EndpointService.disable_endpoint(
|
||||
tenant_id=tenant_id, user_id=user.id, endpoint_id=args.endpoint_id
|
||||
tenant_id=tenant_id, user_id=user_id, endpoint_id=args.endpoint_id
|
||||
)
|
||||
}
|
||||
|
||||
@ -25,13 +25,15 @@ from controllers.console.wraps import (
|
||||
cloud_edition_billing_resource_check,
|
||||
only_edition_enterprise,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from libs.login import login_required
|
||||
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -153,8 +155,9 @@ class TenantListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, current_user: Account):
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
|
||||
@ -228,11 +231,11 @@ class TenantApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@console_ns.response(200, "Success", console_ns.models[TenantInfoResponse.__name__])
|
||||
def post(self):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
if request.path == "/info":
|
||||
logger.warning("Deprecated URL /info was used.")
|
||||
|
||||
current_user, _ = current_account_with_tenant()
|
||||
tenant = current_user.current_tenant
|
||||
if not tenant:
|
||||
raise ValueError("No current tenant")
|
||||
@ -256,8 +259,8 @@ class SwitchWorkspaceApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
payload = console_ns.payload or {}
|
||||
args = SwitchWorkspacePayload.model_validate(payload)
|
||||
|
||||
@ -281,8 +284,8 @@ class CustomConfigWorkspaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceCustomConfigPayload.model_validate(payload)
|
||||
tenant = db.get_or_404(Tenant, current_tenant_id)
|
||||
@ -308,8 +311,8 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_resource_check("workspace_custom")
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# check file
|
||||
if "file" not in request.files:
|
||||
raise NoFileUploadedError()
|
||||
@ -349,8 +352,8 @@ class WorkspaceInfoApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
# Change workspace name
|
||||
def post(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def post(self, current_tenant_id: str):
|
||||
payload = console_ns.payload or {}
|
||||
args = WorkspaceInfoPayload.model_validate(payload)
|
||||
|
||||
@ -372,13 +375,12 @@ class WorkspacePermissionApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_enterprise
|
||||
def get(self):
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
"""
|
||||
Get workspace permission settings.
|
||||
Returns permission flags that control workspace features like member invitations and owner transfer.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Concatenate
|
||||
from typing import Any, Concatenate, overload
|
||||
|
||||
from flask import abort, request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
@ -37,9 +37,21 @@ ERROR_MSG_INVALID_ENCRYPTED_DATA = "Invalid encrypted data"
|
||||
ERROR_MSG_INVALID_ENCRYPTED_CODE = "Invalid encrypted code"
|
||||
|
||||
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def account_initialization_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def account_initialization_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def account_initialization_required[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: Any, **kwargs: Any) -> R:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain functions used in tests and utilities.
|
||||
# check account initialization
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.status == AccountStatus.UNINITIALIZED:
|
||||
@ -218,9 +230,21 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
return decorated
|
||||
|
||||
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@overload
|
||||
def setup_required[T, **P, R](
|
||||
view: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def setup_required[**P, R](view: Callable[P, R]) -> Callable[P, R]: ...
|
||||
|
||||
|
||||
def setup_required[R](view: Callable[..., R]) -> Callable[..., R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
def decorated(*args: Any, **kwargs: Any) -> R:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain functions used in tests and utilities.
|
||||
# check setup
|
||||
if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)):
|
||||
if os.environ.get("INIT_PASSWORD"):
|
||||
@ -552,7 +576,7 @@ def with_current_user_id[T, **P, R](
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, str(current_user.id), *args, **kwargs)
|
||||
return view(self, current_user.id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Concatenate, cast, overload
|
||||
|
||||
from flask import Response, current_app, g, has_request_context, request
|
||||
from flask_login.config import EXEMPT_METHODS
|
||||
@ -48,7 +48,17 @@ def current_account_with_tenant() -> tuple[Account, str]:
|
||||
return user, user.current_tenant_id
|
||||
|
||||
|
||||
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@overload
|
||||
def login_required[T, **P, R](
|
||||
func: Callable[Concatenate[T, P], R],
|
||||
) -> Callable[Concatenate[T, P], R | Response]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]: ...
|
||||
|
||||
|
||||
def login_required[R](func: Callable[..., R]) -> Callable[..., R | Response]:
|
||||
"""
|
||||
If you decorate a view with this, it will ensure that the current user is
|
||||
logged in and authenticated before calling the actual view. (If they are
|
||||
@ -83,7 +93,9 @@ def login_required[**P, R](func: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
def decorated_view(*args: Any, **kwargs: Any) -> R | Response:
|
||||
# The overloads keep Resource methods method-aware for pyrefly while
|
||||
# preserving support for plain Flask view functions.
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Iterator
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
@ -19,31 +21,18 @@ from controllers.console.datasets.data_source import (
|
||||
DataSourceNotionPreviewApi,
|
||||
)
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from models import DataSourceOauthBinding
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
from models import Account, DataSourceOauthBinding
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_ctx():
|
||||
return (MagicMock(id="u1"), "tenant-1")
|
||||
def current_user() -> Account:
|
||||
account = Account(name="Test User", email="u1@example.com")
|
||||
account.id = "u1"
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_tenant(tenant_ctx):
|
||||
with patch(
|
||||
"controllers.console.datasets.data_source.current_account_with_tenant",
|
||||
return_value=tenant_ctx,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_engine():
|
||||
def mock_engine() -> Iterator[None]:
|
||||
with patch.object(
|
||||
type(data_source.db),
|
||||
"engine",
|
||||
@ -55,12 +44,12 @@ def mock_engine():
|
||||
|
||||
class TestDataSourceApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
def test_get_success(self, app: Flask) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
binding = DataSourceOauthBinding(
|
||||
tenant_id="tenant-1",
|
||||
@ -93,7 +82,7 @@ class TestDataSourceApi:
|
||||
return_value=MagicMock(all=lambda: [binding]),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["data"][0] == {
|
||||
@ -120,9 +109,9 @@ class TestDataSourceApi:
|
||||
"link": "http://localhost/console/api/oauth/data-source/notion",
|
||||
}
|
||||
|
||||
def test_get_no_bindings(self, app: Flask, patch_tenant):
|
||||
def test_get_no_bindings(self, app: Flask) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -131,14 +120,14 @@ class TestDataSourceApi:
|
||||
return_value=MagicMock(all=lambda: []),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["data"] == []
|
||||
|
||||
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_patch_enable_binding(self, app: Flask, mock_engine: None) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
@ -152,14 +141,14 @@ class TestDataSourceApi:
|
||||
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "enable")
|
||||
response, status = method(api, "tenant-1", "b1", "enable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is False
|
||||
|
||||
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_patch_disable_binding(self, app: Flask, mock_engine: None) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
@ -173,14 +162,14 @@ class TestDataSourceApi:
|
||||
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
response, status = method(api, "b1", "disable")
|
||||
response, status = method(api, "tenant-1", "b1", "disable")
|
||||
|
||||
assert status == 200
|
||||
assert binding.disabled is True
|
||||
|
||||
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_patch_binding_not_found(self, app: Flask, mock_engine: None) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -191,11 +180,11 @@ class TestDataSourceApi:
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "b1", "enable")
|
||||
method(api, "tenant-1", "b1", "enable")
|
||||
|
||||
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_patch_enable_already_enabled(self, app: Flask, mock_engine: None) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=False)
|
||||
|
||||
@ -208,11 +197,11 @@ class TestDataSourceApi:
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "enable")
|
||||
method(api, "tenant-1", "b1", "enable")
|
||||
|
||||
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_patch_disable_already_disabled(self, app: Flask, mock_engine: None) -> None:
|
||||
api = DataSourceApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
binding = MagicMock(id="b1", disabled=True)
|
||||
|
||||
@ -225,17 +214,17 @@ class TestDataSourceApi:
|
||||
mock_session.execute.return_value.scalar_one_or_none.return_value = binding
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "b1", "disable")
|
||||
method(api, "tenant-1", "b1", "disable")
|
||||
|
||||
|
||||
class TestDataSourceNotionListApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_credential_not_found(self, app: Flask, patch_tenant):
|
||||
def test_get_credential_not_found(self, app: Flask, current_user: Account) -> None:
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=c1"),
|
||||
@ -245,11 +234,11 @@ class TestDataSourceNotionListApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_get_success_no_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None:
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
@ -284,13 +273,13 @@ class TestDataSourceNotionListApi:
|
||||
),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1", current_user)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None:
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
page = MagicMock(
|
||||
page_id="p1",
|
||||
@ -337,13 +326,13 @@ class TestDataSourceNotionListApi:
|
||||
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalars.return_value.all.return_value = [document]
|
||||
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1", current_user)
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
|
||||
def test_get_invalid_dataset_type(self, app: Flask, current_user: Account, mock_engine: None) -> None:
|
||||
api = DataSourceNotionListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
dataset = MagicMock(data_source_type="other_type")
|
||||
|
||||
@ -360,17 +349,17 @@ class TestDataSourceNotionListApi:
|
||||
patch("controllers.console.datasets.data_source.sessionmaker"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
|
||||
class TestDataSourceNotionPreviewApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_preview_success(self, app: Flask, patch_tenant):
|
||||
def test_get_preview_success(self, app: Flask) -> None:
|
||||
api = DataSourceNotionPreviewApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")])
|
||||
|
||||
@ -385,21 +374,22 @@ class TestDataSourceNotionPreviewApi:
|
||||
return_value=extractor,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "p1", "page")
|
||||
response, status = method(api, "tenant-1", "p1", "page")
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDataSourceNotionIndexingEstimateApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
|
||||
def test_post_indexing_estimate_success(self, app: Flask) -> None:
|
||||
api = DataSourceNotionIndexingEstimateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
empty_rules: dict[str, object] = {}
|
||||
payload: dict[str, object] = {
|
||||
"notion_info_list": [
|
||||
{
|
||||
"workspace_id": "w1",
|
||||
@ -407,7 +397,7 @@ class TestDataSourceNotionIndexingEstimateApi:
|
||||
"pages": [{"page_id": "p1", "type": "page"}],
|
||||
}
|
||||
],
|
||||
"process_rule": {"rules": {}},
|
||||
"process_rule": {"rules": empty_rules},
|
||||
"doc_form": IndexStructureType.PARAGRAPH_INDEX,
|
||||
"doc_language": "English",
|
||||
}
|
||||
@ -422,19 +412,19 @@ class TestDataSourceNotionIndexingEstimateApi:
|
||||
return_value=MagicMock(model_dump=lambda: {"total_pages": 1}),
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
|
||||
class TestDataSourceNotionDatasetSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
def test_get_success(self, app: Flask) -> None:
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -455,9 +445,9 @@ class TestDataSourceNotionDatasetSyncApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
|
||||
def test_get_dataset_not_found(self, app: Flask) -> None:
|
||||
api = DataSourceNotionDatasetSyncApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -472,12 +462,12 @@ class TestDataSourceNotionDatasetSyncApi:
|
||||
|
||||
class TestDataSourceNotionDocumentSyncApi:
|
||||
@pytest.fixture
|
||||
def app(self, flask_app_with_containers: Flask):
|
||||
def app(self, flask_app_with_containers: Flask) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
def test_get_success(self, app: Flask) -> None:
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -498,9 +488,9 @@ class TestDataSourceNotionDocumentSyncApi:
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_document_not_found(self, app: Flask, patch_tenant):
|
||||
def test_get_document_not_found(self, app: Flask) -> None:
|
||||
api = DataSourceNotionDocumentSyncApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -23,25 +24,15 @@ from services.datasource_provider_service import DatasourceProviderService
|
||||
from services.plugin.oauth_service import OAuthProxyService
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?credential_id=cred-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
@ -58,20 +49,17 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
response = method(api, "tenant-1", user, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_no_oauth_config(self, app: Flask):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
@ -79,20 +67,16 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", user, "notion")
|
||||
|
||||
def test_get_without_credential_id_sets_cookie(self, app: Flask):
|
||||
api = DatasourcePluginOAuthAuthorizationUrl()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_oauth_client",
|
||||
@ -109,7 +93,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
return_value={"url": "http://auth"},
|
||||
),
|
||||
):
|
||||
response = method(api, "notion")
|
||||
response = method(api, "tenant-1", user, "notion")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert "context_id" in response.headers.get("Set-Cookie")
|
||||
@ -118,7 +102,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
|
||||
class TestDatasourceOAuthCallback:
|
||||
def test_callback_success_new_credential(self, app: Flask):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
@ -160,7 +144,7 @@ class TestDatasourceOAuthCallback:
|
||||
|
||||
def test_callback_missing_context(self, app: Flask):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
@ -168,7 +152,7 @@ class TestDatasourceOAuthCallback:
|
||||
|
||||
def test_callback_invalid_context(self, app: Flask):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?context_id=bad"),
|
||||
@ -183,7 +167,7 @@ class TestDatasourceOAuthCallback:
|
||||
|
||||
def test_callback_oauth_config_not_found(self, app: Flask):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
context = {"user_id": "u", "tenant_id": "t"}
|
||||
|
||||
@ -205,7 +189,7 @@ class TestDatasourceOAuthCallback:
|
||||
|
||||
def test_callback_reauthorize_existing_credential(self, app: Flask):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
@ -248,7 +232,7 @@ class TestDatasourceOAuthCallback:
|
||||
|
||||
def test_callback_context_id_from_cookie(self, app: Flask):
|
||||
api = DatasourceOAuthCallback()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
oauth_response = MagicMock()
|
||||
oauth_response.credentials = {"token": "abc"}
|
||||
@ -292,40 +276,32 @@ class TestDatasourceOAuthCallback:
|
||||
class TestDatasourceAuth:
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "val"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_invalid_credentials(self, app: Flask):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credentials": {"key": "bad"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"add_datasource_api_key_provider",
|
||||
@ -333,63 +309,53 @@ class TestDatasourceAuth:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", "notion")
|
||||
|
||||
def test_get_success(self, app: Flask):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", user, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"]
|
||||
|
||||
def test_post_missing_credentials(self, app: Flask):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", "notion")
|
||||
|
||||
def test_get_empty_list(self, app: Flask):
|
||||
api = DatasourceAuth()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = MagicMock(id="user-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"list_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", user, "notion")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
@ -398,136 +364,112 @@ class TestDatasourceAuth:
|
||||
class TestDatasourceAuthDeleteApi:
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_missing_credential_id(self, app: Flask):
|
||||
api = DatasourceAuthDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", "notion")
|
||||
|
||||
|
||||
class TestDatasourceAuthUpdateApi:
|
||||
def test_update_success(self, app: Flask):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {"k": "v"}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_credentials_none(self, app: Flask):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": None}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
|
||||
def test_update_name_only(self, app: Flask):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
_, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_update_with_empty_credentials_dict(self, app: Flask):
|
||||
api = DatasourceAuthUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "credentials": {}}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_credentials",
|
||||
return_value=None,
|
||||
) as update_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
_, status = method(api, "tenant-1", "notion")
|
||||
|
||||
update_mock.assert_called_once()
|
||||
assert status == 201
|
||||
@ -536,62 +478,50 @@ class TestDatasourceAuthUpdateApi:
|
||||
class TestDatasourceAuthListApi:
|
||||
def test_list_success(self, app: Flask):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_auth_list_empty(self, app: Flask):
|
||||
api = DatasourceAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_all_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
|
||||
def test_hardcode_list_empty(self, app: Flask):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == []
|
||||
@ -600,21 +530,17 @@ class TestDatasourceAuthListApi:
|
||||
class TestDatasourceHardCodeAuthListApi:
|
||||
def test_list_success(self, app: Flask):
|
||||
api = DatasourceHardCodeAuthListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"get_hard_code_datasource_credentials",
|
||||
return_value=[{"id": "1"}],
|
||||
),
|
||||
):
|
||||
response, status = method(api)
|
||||
response, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
@ -622,73 +548,61 @@ class TestDatasourceHardCodeAuthListApi:
|
||||
class TestDatasourceAuthOauthCustomClient:
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"client_params": {}, "enable_oauth_custom_client": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"remove_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_empty_payload(self, app: Flask):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
_, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_post_disabled_flag(self, app: Flask):
|
||||
api = DatasourceAuthOauthCustomClient()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"client_params": {"a": 1},
|
||||
@ -698,17 +612,13 @@ class TestDatasourceAuthOauthCustomClient:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"setup_oauth_custom_client_params",
|
||||
return_value=None,
|
||||
) as setup_mock,
|
||||
):
|
||||
_, status = method(api, "notion")
|
||||
_, status = method(api, "tenant-1", "notion")
|
||||
|
||||
setup_mock.assert_called_once()
|
||||
assert status == 200
|
||||
@ -717,72 +627,60 @@ class TestDatasourceAuthOauthCustomClient:
|
||||
class TestDatasourceAuthDefaultApi:
|
||||
def test_set_default_success(self, app: Flask):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"id": "cred-1"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"set_default_datasource_provider",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_default_missing_id(self, app: Flask):
|
||||
api = DatasourceAuthDefaultApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", "notion")
|
||||
|
||||
|
||||
class TestDatasourceUpdateProviderNameApi:
|
||||
def test_update_name_success(self, app: Flask):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"credential_id": "id", "name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch.object(
|
||||
DatasourceProviderService,
|
||||
"update_datasource_provider_name",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "notion")
|
||||
response, status = method(api, "tenant-1", "notion")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_update_name_too_long(self, app: Flask):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"credential_id": "id",
|
||||
@ -792,27 +690,19 @@ class TestDatasourceUpdateProviderNameApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", "notion")
|
||||
|
||||
def test_update_name_missing_credential_id(self, app: Flask):
|
||||
api = DatasourceUpdateProviderNameApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"name": "Valid"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "notion")
|
||||
method(api, "tenant-1", "notion")
|
||||
|
||||
@ -11,7 +11,7 @@ from flask import Flask
|
||||
|
||||
from controllers.console.datasets import data_source as module
|
||||
from controllers.console.datasets.data_source import DataSourceApi, DataSourceNotionListApi
|
||||
from models import DataSourceOauthBinding
|
||||
from models import Account, DataSourceOauthBinding
|
||||
|
||||
ControllerMethod = Callable[..., tuple[dict[str, object], int]]
|
||||
|
||||
@ -28,13 +28,13 @@ def flask_app() -> Flask:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tenant_context() -> tuple[MagicMock, str]:
|
||||
return MagicMock(id="user-1"), "tenant-1"
|
||||
def current_user() -> Account:
|
||||
account = Account(name="Test User", email="user-1@example.com")
|
||||
account.id = "user-1"
|
||||
return account
|
||||
|
||||
|
||||
def test_get_data_source_integrates_serializes_orm_binding(
|
||||
flask_app: Flask, tenant_context: tuple[MagicMock, str]
|
||||
) -> None:
|
||||
def test_get_data_source_integrates_serializes_orm_binding(flask_app: Flask) -> None:
|
||||
binding = DataSourceOauthBinding(
|
||||
tenant_id="tenant-1",
|
||||
access_token="token",
|
||||
@ -61,10 +61,9 @@ def test_get_data_source_integrates_serializes_orm_binding(
|
||||
|
||||
with (
|
||||
flask_app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=tenant_context),
|
||||
patch.object(module.db.session, "scalars", return_value=MagicMock(all=lambda: [binding])),
|
||||
):
|
||||
response, status = unwrap(DataSourceApi().get)(DataSourceApi())
|
||||
response, status = unwrap(DataSourceApi().get)(DataSourceApi(), "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {
|
||||
@ -96,23 +95,18 @@ def test_get_data_source_integrates_serializes_orm_binding(
|
||||
}
|
||||
|
||||
|
||||
def test_get_data_source_integrates_preserves_empty_list_when_no_binding(
|
||||
flask_app: Flask, tenant_context: tuple[MagicMock, str]
|
||||
) -> None:
|
||||
def test_get_data_source_integrates_preserves_empty_list_when_no_binding(flask_app: Flask) -> None:
|
||||
with (
|
||||
flask_app.test_request_context("/"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=tenant_context),
|
||||
patch.object(module.db.session, "scalars", return_value=MagicMock(all=lambda: [])),
|
||||
):
|
||||
response, status = unwrap(DataSourceApi().get)(DataSourceApi())
|
||||
response, status = unwrap(DataSourceApi().get)(DataSourceApi(), "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"data": []}
|
||||
|
||||
|
||||
def test_notion_pre_import_pages_serializes_frontend_list_shape(
|
||||
flask_app: Flask, tenant_context: tuple[MagicMock, str]
|
||||
) -> None:
|
||||
def test_notion_pre_import_pages_serializes_frontend_list_shape(flask_app: Flask, current_user: Account) -> None:
|
||||
page = MagicMock(
|
||||
page_id="page-1",
|
||||
page_name="Page",
|
||||
@ -137,7 +131,6 @@ def test_notion_pre_import_pages_serializes_frontend_list_shape(
|
||||
|
||||
with (
|
||||
flask_app.test_request_context("/?credential_id=credential-1"),
|
||||
patch.object(module, "current_account_with_tenant", return_value=tenant_context),
|
||||
patch.object(
|
||||
module.DatasourceProviderService,
|
||||
"get_datasource_credentials",
|
||||
@ -147,7 +140,7 @@ def test_notion_pre_import_pages_serializes_frontend_list_shape(
|
||||
patch.object(module, "sessionmaker"),
|
||||
patch("core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", return_value=runtime),
|
||||
):
|
||||
response, status = unwrap(DataSourceNotionListApi().get)(DataSourceNotionListApi())
|
||||
response, status = unwrap(DataSourceNotionListApi().get)(DataSourceNotionListApi(), "tenant-1", current_user)
|
||||
|
||||
assert status == 200
|
||||
assert response == {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -39,12 +40,6 @@ from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def make_serializable_document(**overrides):
|
||||
attrs = {
|
||||
"id": "doc-1",
|
||||
@ -125,11 +120,7 @@ def tenant_ctx():
|
||||
|
||||
@pytest.fixture
|
||||
def patch_tenant(tenant_ctx):
|
||||
with patch(
|
||||
"controllers.console.datasets.datasets_document.current_account_with_tenant",
|
||||
return_value=tenant_ctx,
|
||||
):
|
||||
yield
|
||||
return tenant_ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -173,16 +164,18 @@ def patch_permission():
|
||||
class TestGetProcessRuleApi:
|
||||
def test_get_default_success(self, app: Flask, patch_tenant):
|
||||
api = GetProcessRuleApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, _ = patch_tenant
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(api)
|
||||
response = method(api, user)
|
||||
|
||||
assert "rules" in response
|
||||
|
||||
def test_get_with_document_dataset_not_found(self, app: Flask, patch_tenant):
|
||||
api = GetProcessRuleApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, _ = patch_tenant
|
||||
|
||||
document = MagicMock(dataset_id="ds-1")
|
||||
|
||||
@ -198,13 +191,14 @@ class TestGetProcessRuleApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestDatasetDocumentListApi:
|
||||
def test_get_with_fetch_true_counts_segments(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
doc = make_serializable_document()
|
||||
pagination = MagicMock(items=[doc], total=1)
|
||||
@ -224,7 +218,7 @@ class TestDatasetDocumentListApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
resp = method(api, "ds-1")
|
||||
resp = method(api, tenant_id, user, "ds-1")
|
||||
|
||||
assert resp["data"][0]["id"] == "doc-1"
|
||||
assert resp["data"][0]["completed_segments"] == 2
|
||||
@ -234,7 +228,8 @@ class TestDatasetDocumentListApi:
|
||||
self, app: Flask, patch_tenant, patch_dataset, patch_permission
|
||||
):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
pagination = MagicMock(items=[make_serializable_document()], total=1)
|
||||
|
||||
@ -253,13 +248,14 @@ class TestDatasetDocumentListApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
resp = method(api, "ds-1")
|
||||
resp = method(api, tenant_id, user, "ds-1")
|
||||
|
||||
assert resp["total"] == 1
|
||||
|
||||
def test_get_success(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
pagination = MagicMock(items=[make_serializable_document()], total=1)
|
||||
|
||||
@ -274,7 +270,7 @@ class TestDatasetDocumentListApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1")
|
||||
response = method(api, tenant_id, user, "ds-1")
|
||||
|
||||
assert response["total"] == 1
|
||||
assert response["data"][0]["id"] == "doc-1"
|
||||
@ -283,7 +279,8 @@ class TestDatasetDocumentListApi:
|
||||
|
||||
def test_post_success(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, _ = patch_tenant
|
||||
|
||||
payload = {"indexing_technique": "economy"}
|
||||
created_dataset = make_dataset()
|
||||
@ -306,7 +303,7 @@ class TestDatasetDocumentListApi:
|
||||
),
|
||||
patch("models.dataset.db.session.scalar", return_value=0),
|
||||
):
|
||||
response = method(api, "ds-1")
|
||||
response = method(api, user, "ds-1")
|
||||
|
||||
assert "documents" in response
|
||||
assert response["dataset"]["id"] == "ds-1"
|
||||
@ -318,28 +315,25 @@ class TestDatasetDocumentListApi:
|
||||
|
||||
def test_post_forbidden(self, app: Flask):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
patch.object(type(console_ns), "payload", {}),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds-1")
|
||||
method(api, user, "ds-1")
|
||||
|
||||
def test_get_with_fetch_true_and_invalid_fetch(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
pagination = MagicMock(items=[make_serializable_document()], total=1)
|
||||
|
||||
@ -354,13 +348,14 @@ class TestDatasetDocumentListApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1")
|
||||
response = method(api, tenant_id, user, "ds-1")
|
||||
|
||||
assert response["total"] == 1
|
||||
|
||||
def test_get_sort_hit_count(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
pagination = MagicMock(items=[], total=0)
|
||||
|
||||
@ -375,7 +370,7 @@ class TestDatasetDocumentListApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1")
|
||||
response = method(api, tenant_id, user, "ds-1")
|
||||
|
||||
assert response["total"] == 0
|
||||
|
||||
@ -383,7 +378,8 @@ class TestDatasetDocumentListApi:
|
||||
class TestDatasetInitApi:
|
||||
def test_post_success_serializes_created_dataset_and_documents(self, app: Flask, patch_tenant):
|
||||
api = DatasetInitApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
payload = {"indexing_technique": "economy"}
|
||||
created_dataset = make_dataset()
|
||||
@ -402,7 +398,7 @@ class TestDatasetInitApi:
|
||||
),
|
||||
patch("models.dataset.db.session.scalar", return_value=0),
|
||||
):
|
||||
response = method(api)
|
||||
response = method(api, tenant_id, user)
|
||||
|
||||
assert response["dataset"]["id"] == "ds-1"
|
||||
assert response["documents"][0]["id"] == "doc-init"
|
||||
@ -414,7 +410,8 @@ class TestDatasetInitApi:
|
||||
class TestDocumentApi:
|
||||
def test_get_success(self, app: Flask, patch_tenant):
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(dataset_process_rule=None)
|
||||
|
||||
@ -426,21 +423,23 @@ class TestDocumentApi:
|
||||
return_value={},
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_invalid_metadata(self, app: Flask, patch_tenant):
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with app.test_request_context("/?metadata=wrong"), patch.object(api, "get_document", return_value=MagicMock()):
|
||||
with pytest.raises(InvalidMetadataError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
def test_delete_success(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -454,13 +453,14 @@ class TestDocumentApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 204
|
||||
|
||||
def test_delete_indexing_error(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -475,13 +475,14 @@ class TestDocumentApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(DocumentIndexingError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
|
||||
class TestDocumentDownloadApi:
|
||||
def test_download_success(self, app: Flask, patch_tenant):
|
||||
api = DocumentDownloadApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock()
|
||||
|
||||
@ -493,7 +494,7 @@ class TestDocumentDownloadApi:
|
||||
return_value="url",
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1", "doc-1")
|
||||
response = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert response["url"] == "url"
|
||||
|
||||
@ -501,24 +502,21 @@ class TestDocumentDownloadApi:
|
||||
class TestDocumentProcessingApi:
|
||||
def test_processing_forbidden_when_not_editor(self, app: Flask):
|
||||
api = DocumentProcessingApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
user = MagicMock(is_dataset_editor=False)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.current_account_with_tenant",
|
||||
return_value=(user, "tenant"),
|
||||
),
|
||||
patch.object(api, "get_document", return_value=MagicMock()),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds-1", "doc-1", "pause")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "pause")
|
||||
|
||||
def test_resume_from_error_state(self, app: Flask, patch_tenant):
|
||||
api = DocumentProcessingApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True)
|
||||
|
||||
@ -530,13 +528,14 @@ class TestDocumentProcessingApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
_, status = method(api, "ds-1", "doc-1", "resume")
|
||||
_, status = method(api, tenant_id, user, "ds-1", "doc-1", "resume")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_resume_success(self, app: Flask, patch_tenant):
|
||||
api = DocumentProcessingApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True)
|
||||
|
||||
@ -548,13 +547,14 @@ class TestDocumentProcessingApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "resume")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1", "resume")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_pause_success(self, app: Flask, patch_tenant):
|
||||
api = DocumentProcessingApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(indexing_status="indexing")
|
||||
|
||||
@ -566,25 +566,27 @@ class TestDocumentProcessingApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "pause")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1", "pause")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_pause_invalid(self, app: Flask, patch_tenant):
|
||||
api = DocumentProcessingApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(indexing_status=IndexingStatus.COMPLETED)
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
|
||||
with pytest.raises(InvalidActionError):
|
||||
method(api, "ds-1", "doc-1", "pause")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1", "pause")
|
||||
|
||||
|
||||
class TestDocumentMetadataApi:
|
||||
def test_put_metadata_schema_filtering(self, app: Flask, patch_tenant):
|
||||
api = DocumentMetadataApi()
|
||||
method = unwrap(api.put)
|
||||
method = inspect.unwrap(api.put)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
doc = MagicMock()
|
||||
|
||||
@ -607,13 +609,14 @@ class TestDocumentMetadataApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert doc.doc_metadata == {"amount": 10}
|
||||
|
||||
def test_put_success(self, app: Flask, patch_tenant):
|
||||
api = DocumentMetadataApi()
|
||||
method = unwrap(api.put)
|
||||
method = inspect.unwrap(api.put)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock()
|
||||
|
||||
@ -631,21 +634,23 @@ class TestDocumentMetadataApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_put_invalid_payload(self, app: Flask, patch_tenant):
|
||||
api = DocumentMetadataApi()
|
||||
method = unwrap(api.put)
|
||||
method = inspect.unwrap(api.put)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with app.test_request_context("/", json={}), patch.object(api, "get_document", return_value=MagicMock()):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
def test_put_invalid_doc_type(self, app: Flask, patch_tenant):
|
||||
api = DocumentMetadataApi()
|
||||
method = unwrap(api.put)
|
||||
method = inspect.unwrap(api.put)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
payload = {"doc_type": "invalid", "doc_metadata": {}}
|
||||
|
||||
@ -658,13 +663,14 @@ class TestDocumentMetadataApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
|
||||
class TestDocumentStatusApi:
|
||||
def test_patch_success(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentStatusApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, _ = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/?document_id=doc-1"),
|
||||
@ -681,13 +687,14 @@ class TestDocumentStatusApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "enable")
|
||||
response, status = method(api, user, "ds-1", "enable")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_patch_invalid_action(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentStatusApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, _ = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/?document_id=doc-1"),
|
||||
@ -705,13 +712,13 @@ class TestDocumentStatusApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidActionError):
|
||||
method(api, "ds-1", "enable")
|
||||
method(api, user, "ds-1", "enable")
|
||||
|
||||
|
||||
class TestDocumentRetryApi:
|
||||
def test_retry_archived_document_skipped(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentRetryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"document_ids": ["doc-1"]}
|
||||
|
||||
@ -739,7 +746,7 @@ class TestDocumentRetryApi:
|
||||
|
||||
def test_retry_success(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentRetryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"document_ids": ["doc-1"]}
|
||||
|
||||
@ -768,7 +775,7 @@ class TestDocumentRetryApi:
|
||||
|
||||
def test_retry_skips_completed_document(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentRetryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"document_ids": ["doc-1"]}
|
||||
|
||||
@ -795,7 +802,7 @@ class TestDocumentRetryApi:
|
||||
class TestDocumentPipelineExecutionLogApi:
|
||||
def test_get_log_success(self, app: Flask, patch_tenant, patch_dataset):
|
||||
api = DocumentPipelineExecutionLogApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
log = MagicMock(
|
||||
datasource_info="{}",
|
||||
@ -823,7 +830,8 @@ class TestDocumentPipelineExecutionLogApi:
|
||||
class TestDocumentGenerateSummaryApi:
|
||||
def test_generate_summary_missing_documents(self, app: Flask, patch_tenant, patch_permission):
|
||||
api = DocumentGenerateSummaryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, _ = patch_tenant
|
||||
|
||||
dataset = MagicMock(
|
||||
indexing_technique="high_quality",
|
||||
@ -845,11 +853,12 @@ class TestDocumentGenerateSummaryApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
method(api, user, "ds-1")
|
||||
|
||||
def test_generate_not_enabled(self, app: Flask, patch_tenant, patch_permission):
|
||||
api = DocumentGenerateSummaryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, _ = patch_tenant
|
||||
|
||||
dataset = MagicMock(indexing_technique="high_quality", summary_index_setting={"enable": False})
|
||||
|
||||
@ -864,11 +873,12 @@ class TestDocumentGenerateSummaryApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1")
|
||||
method(api, user, "ds-1")
|
||||
|
||||
def test_generate_summary_success_with_qa_skip(self, app: Flask, patch_tenant, patch_permission):
|
||||
api = DocumentGenerateSummaryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, _ = patch_tenant
|
||||
|
||||
dataset = MagicMock(
|
||||
indexing_technique="high_quality",
|
||||
@ -896,7 +906,7 @@ class TestDocumentGenerateSummaryApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1")
|
||||
response, status = method(api, user, "ds-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
@ -904,7 +914,8 @@ class TestDocumentGenerateSummaryApi:
|
||||
class TestDocumentSummaryStatusApi:
|
||||
def test_get_success(self, app: Flask, patch_tenant, patch_permission):
|
||||
api = DocumentSummaryStatusApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, _ = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -917,7 +928,7 @@ class TestDocumentSummaryStatusApi:
|
||||
return_value={"total_segments": 0},
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
@ -925,7 +936,8 @@ class TestDocumentSummaryStatusApi:
|
||||
class TestDocumentIndexingEstimateApi:
|
||||
def test_indexing_estimate_file_not_found(self, app: Flask, patch_tenant):
|
||||
api = DocumentIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
@ -945,11 +957,12 @@ class TestDocumentIndexingEstimateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
def test_indexing_estimate_generic_exception(self, app: Flask, patch_tenant):
|
||||
api = DocumentIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
@ -982,36 +995,38 @@ class TestDocumentIndexingEstimateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(IndexingEstimateError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
def test_get_finished(self, app: Flask, patch_tenant):
|
||||
api = DocumentIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(indexing_status=IndexingStatus.COMPLETED)
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
|
||||
with pytest.raises(DocumentAlreadyFinishedError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
|
||||
class TestDocumentBatchDownloadZipApi:
|
||||
def test_post_no_documents(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchDownloadZipApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
payload: dict[str, list[str]] = {"document_ids": []}
|
||||
|
||||
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1")
|
||||
method(api, tenant_id, user, "ds-1")
|
||||
|
||||
|
||||
class TestDatasetDocumentListApiDelete:
|
||||
def test_delete_success(self, app: Flask, patch_tenant, patch_dataset):
|
||||
"""Test successful deletion of documents"""
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?document_id=doc-1&document_id=doc-2"),
|
||||
@ -1031,7 +1046,7 @@ class TestDatasetDocumentListApiDelete:
|
||||
def test_delete_indexing_error(self, app: Flask, patch_tenant, patch_dataset):
|
||||
"""Test deletion with indexing error"""
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?document_id=doc-1"),
|
||||
@ -1050,7 +1065,7 @@ class TestDatasetDocumentListApiDelete:
|
||||
def test_delete_dataset_not_found(self, app: Flask, patch_tenant):
|
||||
"""Test deletion when dataset not found"""
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?document_id=doc-1"),
|
||||
@ -1066,7 +1081,8 @@ class TestDatasetDocumentListApiDelete:
|
||||
class TestDocumentBatchIndexingEstimateApi:
|
||||
def test_batch_indexing_estimate_website(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
doc = MagicMock(
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
@ -1089,13 +1105,14 @@ class TestDocumentBatchIndexingEstimateApi:
|
||||
return_value=MagicMock(model_dump=lambda: {"tokens": 2}),
|
||||
),
|
||||
):
|
||||
resp, status = method(api, "ds-1", "batch-1")
|
||||
resp, status = method(api, tenant_id, user, "ds-1", "batch-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_batch_indexing_estimate_notion(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
doc = MagicMock(
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
@ -1117,13 +1134,14 @@ class TestDocumentBatchIndexingEstimateApi:
|
||||
return_value=MagicMock(model_dump=lambda: {"tokens": 1}),
|
||||
),
|
||||
):
|
||||
resp, status = method(api, "ds-1", "batch-1")
|
||||
resp, status = method(api, tenant_id, user, "ds-1", "batch-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_batch_estimate_unsupported_datasource(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
@ -1134,22 +1152,24 @@ class TestDocumentBatchIndexingEstimateApi:
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1", "batch-1")
|
||||
method(api, tenant_id, user, "ds-1", "batch-1")
|
||||
|
||||
def test_get_batch_estimate_invalid_batch(self, app: Flask, patch_tenant):
|
||||
"""Test batch estimation with invalid batch"""
|
||||
api = DocumentBatchIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "invalid-batch")
|
||||
method(api, tenant_id, user, "ds-1", "invalid-batch")
|
||||
|
||||
|
||||
class TestDocumentBatchIndexingStatusApi:
|
||||
def test_get_batch_status_success_serializes_status_shape(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchIndexingStatusApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, _ = patch_tenant
|
||||
|
||||
document = MagicMock(
|
||||
id="doc-1",
|
||||
@ -1173,7 +1193,7 @@ class TestDocumentBatchIndexingStatusApi:
|
||||
side_effect=[2, 3],
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1", "batch-1")
|
||||
response = method(api, user, "ds-1", "batch-1")
|
||||
|
||||
assert response == {
|
||||
"data": [
|
||||
@ -1197,17 +1217,19 @@ class TestDocumentBatchIndexingStatusApi:
|
||||
def test_get_batch_status_invalid_batch(self, app: Flask, patch_tenant):
|
||||
"""Test batch status with invalid batch"""
|
||||
api = DocumentBatchIndexingStatusApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, _ = patch_tenant
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "invalid-batch")
|
||||
method(api, user, "ds-1", "invalid-batch")
|
||||
|
||||
|
||||
class TestDocumentIndexingStatusApi:
|
||||
def test_get_status_success_serializes_status_shape(self, app: Flask, patch_tenant):
|
||||
api = DocumentIndexingStatusApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(
|
||||
id="doc-1",
|
||||
@ -1231,7 +1253,7 @@ class TestDocumentIndexingStatusApi:
|
||||
side_effect=[1, 4],
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1", "doc-1")
|
||||
response = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert response["id"] == "doc-1"
|
||||
assert response["indexing_status"] == "indexing"
|
||||
@ -1241,17 +1263,19 @@ class TestDocumentIndexingStatusApi:
|
||||
def test_get_status_document_not_found(self, app: Flask, patch_tenant):
|
||||
"""Test getting status for non-existent document"""
|
||||
api = DocumentIndexingStatusApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_document", side_effect=NotFound()):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "invalid-doc")
|
||||
method(api, tenant_id, user, "ds-1", "invalid-doc")
|
||||
|
||||
|
||||
class TestDocumentRenameApi:
|
||||
def test_post_success_serializes_document_shape(self, app: Flask, patch_tenant):
|
||||
api = DocumentRenameApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, _ = patch_tenant
|
||||
|
||||
payload = {"name": "Renamed Document"}
|
||||
renamed_document = make_document(id="doc-renamed", name="Renamed Document")
|
||||
@ -1273,7 +1297,7 @@ class TestDocumentRenameApi:
|
||||
),
|
||||
patch("models.dataset.db.session.scalar", return_value=0),
|
||||
):
|
||||
response = method(api, "ds-1", "doc-1")
|
||||
response = method(api, user, "ds-1", "doc-1")
|
||||
|
||||
assert response["id"] == "doc-renamed"
|
||||
assert response["name"] == "Renamed Document"
|
||||
@ -1286,7 +1310,8 @@ class TestDocumentApiMetadata:
|
||||
def test_get_with_only_option(self, app: Flask, patch_tenant):
|
||||
"""Test get with 'only' metadata option"""
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(dataset_process_rule=None, doc_metadata_details=[])
|
||||
|
||||
@ -1298,14 +1323,15 @@ class TestDocumentApiMetadata:
|
||||
return_value={},
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_with_without_option(self, app: Flask, patch_tenant):
|
||||
"""Test get with 'without' metadata option"""
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(dataset_process_rule=None)
|
||||
|
||||
@ -1317,7 +1343,7 @@ class TestDocumentApiMetadata:
|
||||
return_value={},
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
@ -1326,7 +1352,8 @@ class TestDocumentGenerateSummaryApiSuccess:
|
||||
def test_generate_not_enabled_high_quality(self, app: Flask, patch_tenant, patch_permission):
|
||||
"""Test summary generation on non-high-quality dataset"""
|
||||
api = DocumentGenerateSummaryApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user, _ = patch_tenant
|
||||
|
||||
dataset = MagicMock(indexing_technique="economy", summary_index_setting={"enable": True})
|
||||
|
||||
@ -1341,26 +1368,28 @@ class TestDocumentGenerateSummaryApiSuccess:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1")
|
||||
method(api, user, "ds-1")
|
||||
|
||||
|
||||
class TestDocumentProcessingApiResume:
|
||||
def test_resume_invalid_status(self, app: Flask, patch_tenant):
|
||||
"""Test resume on non-paused document"""
|
||||
api = DocumentProcessingApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False)
|
||||
|
||||
with app.test_request_context("/"), patch.object(api, "get_document", return_value=document):
|
||||
with pytest.raises(InvalidActionError):
|
||||
method(api, "ds-1", "doc-1", "resume")
|
||||
method(api, tenant_id, user, "ds-1", "doc-1", "resume")
|
||||
|
||||
|
||||
class TestDocumentPermissionCases:
|
||||
def test_document_batch_get_permission_denied(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -1374,11 +1403,12 @@ class TestDocumentPermissionCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds-1", "batch-1")
|
||||
method(api, tenant_id, user, "ds-1", "batch-1")
|
||||
|
||||
def test_document_batch_get_documents_not_found(self, app: Flask, patch_tenant):
|
||||
api = DocumentBatchIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -1392,7 +1422,7 @@ class TestDocumentPermissionCases:
|
||||
),
|
||||
patch.object(api, "get_batch_documents", return_value=None),
|
||||
):
|
||||
response, status = method(api, "ds-1", "batch-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "batch-1")
|
||||
|
||||
assert status == 200
|
||||
assert response == {
|
||||
@ -1405,7 +1435,7 @@ class TestDocumentPermissionCases:
|
||||
|
||||
def test_document_tenant_mismatch(self, app: Flask):
|
||||
api = DocumentApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
document = MagicMock(
|
||||
@ -1415,13 +1445,9 @@ class TestDocumentPermissionCases:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.DatasetService.get_dataset",
|
||||
return_value=MagicMock(), # ✅ prevents real DB call
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.DocumentService.get_document",
|
||||
@ -1433,11 +1459,12 @@ class TestDocumentPermissionCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_process_rule_get_by_document_success(self, app: Flask, patch_tenant):
|
||||
api = GetProcessRuleApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, _ = patch_tenant
|
||||
|
||||
document = MagicMock(dataset_id="ds-1")
|
||||
process_rule = MagicMock(mode="custom", rules_dict={"a": 1})
|
||||
@ -1461,7 +1488,7 @@ class TestDocumentPermissionCases:
|
||||
return_value=process_rule,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, user)
|
||||
|
||||
if isinstance(result, tuple):
|
||||
response, status = result
|
||||
@ -1473,16 +1500,13 @@ class TestDocumentPermissionCases:
|
||||
|
||||
def test_process_rule_permission_denied(self, app: Flask):
|
||||
api = GetProcessRuleApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
document = MagicMock(dataset_id="ds-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/?document_id=doc-1"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.current_account_with_tenant",
|
||||
return_value=(MagicMock(is_dataset_editor=True), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_document.db.get_or_404",
|
||||
return_value=document,
|
||||
@ -1497,14 +1521,15 @@ class TestDocumentPermissionCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestDocumentListAdvancedCases:
|
||||
def test_document_list_with_multiple_sort_options(self, app: Flask, patch_tenant, patch_dataset, patch_permission):
|
||||
"""Test document list with different sort options"""
|
||||
api = DatasetDocumentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
pagination = MagicMock(items=[make_serializable_document()], total=1)
|
||||
|
||||
@ -1519,14 +1544,15 @@ class TestDocumentListAdvancedCases:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response = method(api, "ds-1")
|
||||
response = method(api, tenant_id, user, "ds-1")
|
||||
|
||||
assert response["total"] == 1
|
||||
|
||||
def test_document_metadata_with_schema_validation(self, app: Flask, patch_tenant):
|
||||
"""Test document metadata update with schema validation"""
|
||||
api = DocumentMetadataApi()
|
||||
method = unwrap(api.put)
|
||||
method = inspect.unwrap(api.put)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
doc = MagicMock()
|
||||
payload = {
|
||||
@ -1548,7 +1574,7 @@ class TestDocumentListAdvancedCases:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
assert doc.doc_metadata == {"amount": 5000, "currency": "USD"}
|
||||
@ -1557,7 +1583,8 @@ class TestDocumentListAdvancedCases:
|
||||
class TestDocumentIndexingEdgeCases:
|
||||
def test_document_indexing_with_extraction_setting(self, app: Flask, patch_tenant):
|
||||
api = DocumentIndexingEstimateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user, tenant_id = patch_tenant
|
||||
|
||||
document = MagicMock(
|
||||
indexing_status=IndexingStatus.INDEXING,
|
||||
@ -1586,6 +1613,6 @@ class TestDocumentIndexingEdgeCases:
|
||||
return_value=MagicMock(model_dump=lambda: {"tokens": 5}),
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, tenant_id, user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
@ -8,11 +8,11 @@ upload-file documents, and rejects unsupported or missing file cases.
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
from collections import UserDict
|
||||
from io import BytesIO
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
@ -79,7 +79,7 @@ def _mock_document(
|
||||
upload_file_id: str | None,
|
||||
) -> SimpleNamespace:
|
||||
"""Build a minimal document object used by the controller."""
|
||||
data_source_info_dict: dict[str, Any] | None = None
|
||||
data_source_info_dict: dict[str, object] | None = None
|
||||
if upload_file_id is not None:
|
||||
data_source_info_dict = {"upload_file_id": upload_file_id}
|
||||
else:
|
||||
@ -97,7 +97,6 @@ def _wire_common_success_mocks(
|
||||
*,
|
||||
module,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
current_tenant_id: str,
|
||||
document_tenant_id: str,
|
||||
data_source_type: str,
|
||||
upload_file_id: str | None,
|
||||
@ -107,9 +106,6 @@ def _wire_common_success_mocks(
|
||||
"""Patch controller dependencies to create a deterministic test environment."""
|
||||
import services.dataset_service as dataset_service_module
|
||||
|
||||
# Make `current_account_with_tenant()` return a known user + tenant id.
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (_mock_user(), current_tenant_id))
|
||||
|
||||
# Return a dataset object and allow permission checks to pass.
|
||||
monkeypatch.setattr(module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1"))
|
||||
monkeypatch.setattr(module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None)
|
||||
@ -124,9 +120,9 @@ def _wire_common_success_mocks(
|
||||
monkeypatch.setattr(module.DocumentService, "get_document", lambda *_args, **_kwargs: document)
|
||||
|
||||
# Mock UploadFile lookup via FileService batch helper.
|
||||
upload_files_by_id: dict[str, Any] = {}
|
||||
upload_files_by_id: dict[str, object] = {}
|
||||
if upload_file_exists and upload_file_id is not None:
|
||||
upload_files_by_id[str(upload_file_id)] = SimpleNamespace(id=str(upload_file_id))
|
||||
upload_files_by_id[upload_file_id] = SimpleNamespace(id=upload_file_id)
|
||||
monkeypatch.setattr(module.FileService, "get_upload_files_by_ids", lambda *_args, **_kwargs: upload_files_by_id)
|
||||
|
||||
# Mock signing helper so the returned URL is deterministic.
|
||||
@ -153,8 +149,6 @@ def test_batch_download_zip_returns_send_file(
|
||||
) -> None:
|
||||
"""Ensure batch ZIP download returns a zip attachment via `send_file`."""
|
||||
|
||||
# Arrange common permission mocks.
|
||||
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
|
||||
monkeypatch.setattr(
|
||||
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
|
||||
)
|
||||
@ -204,7 +198,8 @@ def test_batch_download_zip_returns_send_file(
|
||||
json={"document_ids": ["11111111-1111-1111-1111-111111111111", "22222222-2222-2222-2222-222222222222"]},
|
||||
):
|
||||
api = datasets_document_module.DocumentBatchDownloadZipApi()
|
||||
result = api.post(dataset_id="ds-1")
|
||||
method = inspect.unwrap(api.post)
|
||||
result = method(api, "tenant-123", _mock_user(), dataset_id="ds-1")
|
||||
|
||||
# Assert: we returned via send_file with correct mime type and attachment.
|
||||
assert result["_send_file_kwargs"]["mimetype"] == "application/zip"
|
||||
@ -222,7 +217,6 @@ def test_batch_download_zip_response_is_openable_zip(
|
||||
"""Ensure the real Flask `send_file` response body is a valid ZIP that can be opened."""
|
||||
|
||||
# Arrange: same controller mocks as the lightweight send_file test, but we keep the real `send_file`.
|
||||
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
|
||||
monkeypatch.setattr(
|
||||
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
|
||||
)
|
||||
@ -270,7 +264,8 @@ def test_batch_download_zip_response_is_openable_zip(
|
||||
json={"document_ids": ["33333333-3333-3333-3333-333333333333", "44444444-4444-4444-4444-444444444444"]},
|
||||
):
|
||||
api = datasets_document_module.DocumentBatchDownloadZipApi()
|
||||
response = api.post(dataset_id="ds-1")
|
||||
method = inspect.unwrap(api.post)
|
||||
response = method(api, "tenant-123", _mock_user(), dataset_id="ds-1")
|
||||
|
||||
# Assert: response body is a valid ZIP and contains the expected entries.
|
||||
response.direct_passthrough = False
|
||||
@ -288,7 +283,6 @@ def test_batch_download_zip_rejects_non_upload_file_document(
|
||||
) -> None:
|
||||
"""Ensure batch ZIP download rejects non upload-file documents."""
|
||||
|
||||
monkeypatch.setattr(datasets_document_module, "current_account_with_tenant", lambda: (_mock_user(), "tenant-123"))
|
||||
monkeypatch.setattr(
|
||||
datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")
|
||||
)
|
||||
@ -314,8 +308,9 @@ def test_batch_download_zip_rejects_non_upload_file_document(
|
||||
json={"document_ids": ["55555555-5555-5555-5555-555555555555"]},
|
||||
):
|
||||
api = datasets_document_module.DocumentBatchDownloadZipApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(NotFound):
|
||||
api.post(dataset_id="ds-1")
|
||||
method(api, "tenant-123", _mock_user(), dataset_id="ds-1")
|
||||
|
||||
|
||||
def test_document_download_returns_url_for_upload_file_document(
|
||||
@ -326,7 +321,6 @@ def test_document_download_returns_url_for_upload_file_document(
|
||||
_wire_common_success_mocks(
|
||||
module=datasets_document_module,
|
||||
monkeypatch=monkeypatch,
|
||||
current_tenant_id="tenant-123",
|
||||
document_tenant_id="tenant-123",
|
||||
data_source_type="upload_file",
|
||||
upload_file_id="file-123",
|
||||
@ -337,7 +331,8 @@ def test_document_download_returns_url_for_upload_file_document(
|
||||
# Build a request context then call the resource method directly.
|
||||
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
|
||||
api = datasets_document_module.DocumentDownloadApi()
|
||||
result = api.get(dataset_id="ds-1", document_id="doc-1")
|
||||
method = inspect.unwrap(api.get)
|
||||
result = method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
|
||||
|
||||
assert result == {"url": "https://example.com/signed"}
|
||||
|
||||
@ -350,7 +345,6 @@ def test_document_download_rejects_non_upload_file_document(
|
||||
_wire_common_success_mocks(
|
||||
module=datasets_document_module,
|
||||
monkeypatch=monkeypatch,
|
||||
current_tenant_id="tenant-123",
|
||||
document_tenant_id="tenant-123",
|
||||
data_source_type="website_crawl",
|
||||
upload_file_id="file-123",
|
||||
@ -360,8 +354,9 @@ def test_document_download_rejects_non_upload_file_document(
|
||||
|
||||
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
|
||||
api = datasets_document_module.DocumentDownloadApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
with pytest.raises(NotFound):
|
||||
api.get(dataset_id="ds-1", document_id="doc-1")
|
||||
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
|
||||
|
||||
|
||||
def test_document_download_rejects_missing_upload_file_id(
|
||||
@ -372,7 +367,6 @@ def test_document_download_rejects_missing_upload_file_id(
|
||||
_wire_common_success_mocks(
|
||||
module=datasets_document_module,
|
||||
monkeypatch=monkeypatch,
|
||||
current_tenant_id="tenant-123",
|
||||
document_tenant_id="tenant-123",
|
||||
data_source_type="upload_file",
|
||||
upload_file_id=None,
|
||||
@ -382,8 +376,9 @@ def test_document_download_rejects_missing_upload_file_id(
|
||||
|
||||
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
|
||||
api = datasets_document_module.DocumentDownloadApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
with pytest.raises(NotFound):
|
||||
api.get(dataset_id="ds-1", document_id="doc-1")
|
||||
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
|
||||
|
||||
|
||||
def test_document_download_rejects_when_upload_file_record_missing(
|
||||
@ -394,7 +389,6 @@ def test_document_download_rejects_when_upload_file_record_missing(
|
||||
_wire_common_success_mocks(
|
||||
module=datasets_document_module,
|
||||
monkeypatch=monkeypatch,
|
||||
current_tenant_id="tenant-123",
|
||||
document_tenant_id="tenant-123",
|
||||
data_source_type="upload_file",
|
||||
upload_file_id="file-123",
|
||||
@ -404,8 +398,9 @@ def test_document_download_rejects_when_upload_file_record_missing(
|
||||
|
||||
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
|
||||
api = datasets_document_module.DocumentDownloadApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
with pytest.raises(NotFound):
|
||||
api.get(dataset_id="ds-1", document_id="doc-1")
|
||||
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
|
||||
|
||||
|
||||
def test_document_download_rejects_tenant_mismatch(
|
||||
@ -416,7 +411,6 @@ def test_document_download_rejects_tenant_mismatch(
|
||||
_wire_common_success_mocks(
|
||||
module=datasets_document_module,
|
||||
monkeypatch=monkeypatch,
|
||||
current_tenant_id="tenant-123",
|
||||
document_tenant_id="tenant-999",
|
||||
data_source_type="upload_file",
|
||||
upload_file_id="file-123",
|
||||
@ -426,5 +420,6 @@ def test_document_download_rejects_tenant_mismatch(
|
||||
|
||||
with app.test_request_context("/datasets/ds-1/documents/doc-1/download", method="GET"):
|
||||
api = datasets_document_module.DocumentDownloadApi()
|
||||
method = inspect.unwrap(api.get)
|
||||
with pytest.raises(Forbidden):
|
||||
api.get(dataset_id="ds-1", document_id="doc-1")
|
||||
method(api, "tenant-123", _mock_user(), dataset_id="ds-1", document_id="doc-1")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -34,12 +35,6 @@ from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDelete
|
||||
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _segment():
|
||||
segment = DocumentSegment(
|
||||
tenant_id="tenant-1",
|
||||
@ -129,10 +124,11 @@ def test_segment_response_with_summary():
|
||||
class TestDatasetDocumentSegmentListApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
dataset = MagicMock()
|
||||
document = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
segment = _segment()
|
||||
|
||||
@ -143,10 +139,6 @@ class TestDatasetDocumentSegmentListApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -170,40 +162,34 @@ class TestDatasetDocumentSegmentListApi:
|
||||
patch("models.dataset.db.session.scalar", return_value=None),
|
||||
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
|
||||
def test_get_dataset_not_found(self, app: Flask):
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_get_permission_denied(self, app: Flask):
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
dataset = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -214,13 +200,13 @@ class TestDatasetDocumentSegmentListApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
|
||||
class TestDatasetDocumentSegmentApi:
|
||||
def test_patch_success(self, app: Flask):
|
||||
api = DatasetDocumentSegmentApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
@ -233,10 +219,6 @@ class TestDatasetDocumentSegmentApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?segment_id=s1&segment_id=s2"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -258,14 +240,14 @@ class TestDatasetDocumentSegmentApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "enable")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "enable")
|
||||
|
||||
assert status == 200
|
||||
assert response["result"] == "success"
|
||||
|
||||
def test_patch_document_indexing_in_progress(self, app: Flask):
|
||||
api = DatasetDocumentSegmentApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
@ -278,10 +260,6 @@ class TestDatasetDocumentSegmentApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -304,11 +282,11 @@ class TestDatasetDocumentSegmentApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidActionError):
|
||||
method(api, "ds-1", "doc-1", "disable")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "disable")
|
||||
|
||||
def test_patch_llm_bad_request(self, app: Flask):
|
||||
api = DatasetDocumentSegmentApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
@ -322,10 +300,6 @@ class TestDatasetDocumentSegmentApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?segment_id=s1"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -348,11 +322,11 @@ class TestDatasetDocumentSegmentApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, "ds-1", "doc-1", "enable")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "enable")
|
||||
|
||||
def test_patch_provider_token_not_init(self, app: Flask):
|
||||
api = DatasetDocumentSegmentApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
@ -366,10 +340,6 @@ class TestDatasetDocumentSegmentApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?segment_id=s1"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -392,13 +362,13 @@ class TestDatasetDocumentSegmentApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, "ds-1", "doc-1", "enable")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "enable")
|
||||
|
||||
|
||||
class TestDatasetDocumentSegmentAddApi:
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DatasetDocumentSegmentAddApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"content": "hello"}
|
||||
|
||||
@ -416,10 +386,6 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -447,14 +413,14 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
patch("models.dataset.db.session.scalar", return_value=None),
|
||||
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["data"]["id"] == "seg-1"
|
||||
|
||||
def test_post_llm_bad_request(self, app: Flask):
|
||||
api = DatasetDocumentSegmentAddApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"content": "x"}
|
||||
|
||||
@ -471,10 +437,6 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -489,11 +451,11 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_post_provider_token_not_init(self, app: Flask):
|
||||
api = DatasetDocumentSegmentAddApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"content": "x"}
|
||||
|
||||
@ -510,10 +472,6 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -528,13 +486,13 @@ class TestDatasetDocumentSegmentAddApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
|
||||
class TestDatasetDocumentSegmentUpdateApi:
|
||||
def test_patch_success(self, app: Flask):
|
||||
api = DatasetDocumentSegmentUpdateApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
payload = {"content": "updated"}
|
||||
|
||||
@ -552,10 +510,6 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -586,14 +540,14 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
),
|
||||
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "seg-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
|
||||
|
||||
assert status == 200
|
||||
assert "data" in response
|
||||
|
||||
def test_patch_llm_bad_request(self, app: Flask):
|
||||
api = DatasetDocumentSegmentUpdateApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
payload = {"content": "x"}
|
||||
|
||||
@ -610,10 +564,6 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -632,26 +582,23 @@ class TestDatasetDocumentSegmentUpdateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method(api, "ds-1", "doc-1", "seg-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
|
||||
|
||||
|
||||
class TestDatasetDocumentSegmentBatchImportApi:
|
||||
def test_post_success(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"upload_file_id": "file-1"}
|
||||
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.name = "test.csv"
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -673,45 +620,39 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["job_status"] == "waiting"
|
||||
|
||||
def test_post_dataset_not_found(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"upload_file_id": "file-1"}
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_post_document_not_found(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"upload_file_id": "file-1"}
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -722,21 +663,18 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_post_upload_file_not_found(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"upload_file_id": "file-1"}
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -751,24 +689,21 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_post_invalid_file_type(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"upload_file_id": "file-1"}
|
||||
|
||||
upload_file = MagicMock()
|
||||
upload_file.name = "test.txt"
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -783,24 +718,21 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_post_async_task_failure(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"upload_file_id": "file-1"}
|
||||
|
||||
upload_file = MagicMock()
|
||||
upload_file.name = "test.csv"
|
||||
user = MagicMock(id="u1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -818,14 +750,14 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
side_effect=Exception("redis down"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 500
|
||||
assert "error" in response
|
||||
|
||||
def test_get_job_not_found_in_redis(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -840,23 +772,19 @@ class TestDatasetDocumentSegmentBatchImportApi:
|
||||
|
||||
class TestChildChunkAddApi:
|
||||
def test_patch_documents_batch_update_payload(self):
|
||||
api_doc = unwrap(ChildChunkAddApi.patch).__apidoc__
|
||||
api_doc = getattr(ChildChunkAddApi.patch, "__apidoc__") # noqa: B009
|
||||
expected_model = ChildChunkBatchUpdatePayload.__name__
|
||||
|
||||
assert [model.name for model in api_doc["expect"]] == [expected_model]
|
||||
|
||||
def test_get_uses_default_pagination_for_malformed_ints(self, app: Flask):
|
||||
api = ChildChunkAddApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
pagination = MagicMock(items=[], total=0, pages=0)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=bad&limit="),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -878,7 +806,7 @@ class TestChildChunkAddApi:
|
||||
return_value=pagination,
|
||||
) as get_child_chunks,
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "seg-1")
|
||||
response, status = method(api, "tenant-1", "ds-1", "doc-1", "seg-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["page"] == 1
|
||||
@ -887,7 +815,7 @@ class TestChildChunkAddApi:
|
||||
|
||||
def test_post_success(self, app: Flask):
|
||||
api = ChildChunkAddApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"content": "child"}
|
||||
|
||||
@ -904,10 +832,6 @@ class TestChildChunkAddApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -929,14 +853,14 @@ class TestChildChunkAddApi:
|
||||
return_value=child_chunk,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "seg-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
|
||||
|
||||
assert status == 200
|
||||
assert response["data"]["id"] == "cc-1"
|
||||
|
||||
def test_post_child_chunk_indexing_error(self, app: Flask):
|
||||
api = ChildChunkAddApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"content": "child"}
|
||||
|
||||
@ -949,10 +873,6 @@ class TestChildChunkAddApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -975,13 +895,13 @@ class TestChildChunkAddApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ChildChunkIndexingError):
|
||||
method(api, "ds-1", "doc-1", "seg-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1")
|
||||
|
||||
|
||||
class TestChildChunkUpdateApi:
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = ChildChunkUpdateApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
@ -993,10 +913,6 @@ class TestChildChunkUpdateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1018,14 +934,14 @@ class TestChildChunkUpdateApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1", "cc-1")
|
||||
|
||||
assert status == 204
|
||||
assert response == ""
|
||||
|
||||
def test_delete_child_chunk_index_error(self, app: Flask):
|
||||
api = ChildChunkUpdateApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
|
||||
@ -1036,10 +952,6 @@ class TestChildChunkUpdateApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1062,16 +974,17 @@ class TestChildChunkUpdateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ChildChunkDeleteIndexError):
|
||||
method(api, "ds-1", "doc-1", "seg-1", "cc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1", "seg-1", "cc-1")
|
||||
|
||||
|
||||
class TestSegmentListAdvancedCases:
|
||||
def test_segment_list_with_keyword_filter(self, app: Flask):
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
dataset = MagicMock()
|
||||
document = MagicMock()
|
||||
user = MagicMock()
|
||||
|
||||
segment = _segment()
|
||||
|
||||
@ -1079,10 +992,6 @@ class TestSegmentListAdvancedCases:
|
||||
|
||||
with (
|
||||
app.test_request_context("/?keyword=test"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1106,7 +1015,7 @@ class TestSegmentListAdvancedCases:
|
||||
patch("models.dataset.db.session.scalar", return_value=None),
|
||||
patch("models.dataset.db.session.execute", return_value=MagicMock(all=MagicMock(return_value=[]))),
|
||||
):
|
||||
result = method(api, "ds-1", "doc-1")
|
||||
result = method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
if isinstance(result, tuple):
|
||||
response, status = result
|
||||
@ -1118,18 +1027,15 @@ class TestSegmentListAdvancedCases:
|
||||
|
||||
def test_segment_list_postgres_keyword_filter_handles_scalar_keywords(self, app: Flask):
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
dataset = MagicMock()
|
||||
document = MagicMock()
|
||||
user = MagicMock()
|
||||
pagination = MagicMock(items=[], total=0, pages=0)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?keyword=test"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "11111111-1111-1111-1111-111111111111"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1151,7 +1057,13 @@ class TestSegmentListAdvancedCases:
|
||||
return_value=pagination,
|
||||
) as paginate_mock,
|
||||
):
|
||||
method(api, "22222222-2222-2222-2222-222222222222", "33333333-3333-3333-3333-333333333333")
|
||||
method(
|
||||
api,
|
||||
"11111111-1111-1111-1111-111111111111",
|
||||
user,
|
||||
"22222222-2222-2222-2222-222222222222",
|
||||
"33333333-3333-3333-3333-333333333333",
|
||||
)
|
||||
|
||||
query = paginate_mock.call_args.kwargs["select"]
|
||||
sql = str(query.compile(compile_kwargs={"literal_binds": True}))
|
||||
@ -1161,14 +1073,11 @@ class TestSegmentListAdvancedCases:
|
||||
def test_segment_list_permission_denied(self, app: Flask):
|
||||
"""Test segment list with permission denied"""
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=MagicMock(),
|
||||
@ -1179,33 +1088,30 @@ class TestSegmentListAdvancedCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_segment_list_dataset_not_found(self, app: Flask):
|
||||
"""Test segment list with dataset not found"""
|
||||
api = DatasetDocumentSegmentListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
|
||||
class TestSegmentOperationCases:
|
||||
def test_segment_add_with_provider_token_error(self, app: Flask):
|
||||
"""Test segment add with provider token not initialized"""
|
||||
api = DatasetDocumentSegmentAddApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
@ -1216,10 +1122,6 @@ class TestSegmentOperationCases:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1238,12 +1140,12 @@ class TestSegmentOperationCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProviderTokenNotInitError):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_batch_import_with_document_not_found(self, app: Flask):
|
||||
"""Test batch import with document not found"""
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
@ -1253,10 +1155,6 @@ class TestSegmentOperationCases:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1267,12 +1165,12 @@ class TestSegmentOperationCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_batch_import_with_invalid_file(self, app: Flask):
|
||||
"""Test batch import with invalid file type"""
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
@ -1284,10 +1182,6 @@ class TestSegmentOperationCases:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1302,11 +1196,11 @@ class TestSegmentOperationCases:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1", "doc-1")
|
||||
method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
def test_batch_import_with_async_task_failure(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
dataset = MagicMock()
|
||||
@ -1319,10 +1213,6 @@ class TestSegmentOperationCases:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.DatasetService.get_dataset",
|
||||
return_value=dataset,
|
||||
@ -1344,23 +1234,17 @@ class TestSegmentOperationCases:
|
||||
side_effect=Exception("Task failed"),
|
||||
),
|
||||
):
|
||||
response, status = method(api, "ds-1", "doc-1")
|
||||
response, status = method(api, "tenant-1", user, "ds-1", "doc-1")
|
||||
|
||||
assert status == 500
|
||||
assert "error" in response
|
||||
|
||||
def test_batch_import_get_job_not_found(self, app: Flask):
|
||||
api = DatasetDocumentSegmentBatchImportApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
user = MagicMock(is_dataset_editor=True)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?job_id=invalid-job"),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.current_account_with_tenant",
|
||||
return_value=(user, "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.datasets.datasets_segments.redis_client.get",
|
||||
return_value=None,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from importlib import import_module
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -16,50 +16,32 @@ from controllers.console.datasets.external import (
|
||||
ExternalDatasetCreateApi,
|
||||
ExternalKnowledgeHitTestingApi,
|
||||
)
|
||||
from models.account import Account, TenantAccountRole
|
||||
from services.dataset_service import DatasetService
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.hit_testing_service import HitTestingService
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
external_controller = import_module("controllers.console.datasets.external")
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
def app() -> Flask:
|
||||
app = Flask("test_external_dataset")
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def current_user():
|
||||
user = MagicMock()
|
||||
def current_user() -> Account:
|
||||
user = Account(name="Test User", email="user-1@example.com")
|
||||
user.id = "user-1"
|
||||
user.is_dataset_editor = True
|
||||
user.has_edit_permission = True
|
||||
user.is_dataset_operator = True
|
||||
user.role = TenantAccountRole.EDITOR
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth(monkeypatch: pytest.MonkeyPatch, current_user):
|
||||
monkeypatch.setattr(
|
||||
external_controller,
|
||||
"current_account_with_tenant",
|
||||
lambda: (current_user, "tenant-1"),
|
||||
)
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
api_item = MagicMock()
|
||||
api_item.to_dict.return_value = {"id": "1"}
|
||||
@ -79,10 +61,10 @@ class TestExternalApiTemplateListApi:
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
|
||||
def test_post_forbidden(self, app: Flask, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
def test_post_forbidden(self, app: Flask, current_user: Account):
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
@ -92,11 +74,11 @@ class TestExternalApiTemplateListApi:
|
||||
patch.object(ExternalDatasetService, "validate_api_list"),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
def test_post_duplicate_name(self, app: Flask):
|
||||
def test_post_duplicate_name(self, app: Flask, current_user: Account):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"name": "x", "settings": {"k": "v"}}
|
||||
|
||||
@ -111,13 +93,13 @@ class TestExternalApiTemplateListApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
|
||||
class TestExternalApiTemplateApi:
|
||||
def test_get_not_found(self, app: Flask):
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -128,24 +110,23 @@ class TestExternalApiTemplateApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "api-id")
|
||||
method(api, "tenant-1", "api-id")
|
||||
|
||||
def test_delete_forbidden(self, app: Flask, current_user):
|
||||
current_user.has_edit_permission = False
|
||||
current_user.is_dataset_operator = False
|
||||
def test_delete_forbidden(self, app: Flask, current_user: Account):
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
|
||||
api = ExternalApiTemplateApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "api-id")
|
||||
method(api, "tenant-1", current_user, "api-id")
|
||||
|
||||
|
||||
class TestExternalApiUseCheckApi:
|
||||
def test_get_scopes_usage_check_to_current_tenant(self, app: Flask):
|
||||
api = ExternalApiUseCheckApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -155,7 +136,7 @@ class TestExternalApiUseCheckApi:
|
||||
return_value=(True, 2),
|
||||
) as mock_use_check,
|
||||
):
|
||||
response, status = method(api, "api-id")
|
||||
response, status = method(api, "tenant-1", "api-id")
|
||||
|
||||
assert status == 200
|
||||
assert response == {"is_using": True, "count": 2}
|
||||
@ -163,9 +144,9 @@ class TestExternalApiUseCheckApi:
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApi:
|
||||
def test_create_success(self, app: Flask):
|
||||
def test_create_success(self, app: Flask, current_user: Account):
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
@ -203,14 +184,14 @@ class TestExternalDatasetCreateApi:
|
||||
return_value=dataset,
|
||||
),
|
||||
):
|
||||
_, status = method(api)
|
||||
_, status = method(api, "tenant-1", current_user)
|
||||
|
||||
assert status == 201
|
||||
|
||||
def test_create_forbidden(self, app: Flask, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
def test_create_forbidden(self, app: Flask, current_user: Account):
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api",
|
||||
@ -223,13 +204,13 @@ class TestExternalDatasetCreateApi:
|
||||
patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApi:
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask):
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, current_user: Account):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
@ -240,11 +221,11 @@ class TestExternalKnowledgeHitTestingApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "dataset-id")
|
||||
method(api, current_user, "dataset-id")
|
||||
|
||||
def test_hit_testing_success(self, app: Flask):
|
||||
def test_hit_testing_success(self, app: Flask, current_user: Account):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"query": "hello"}
|
||||
|
||||
@ -261,7 +242,7 @@ class TestExternalKnowledgeHitTestingApi:
|
||||
return_value={"ok": True},
|
||||
),
|
||||
):
|
||||
resp = method(api, "dataset-id")
|
||||
resp = method(api, current_user, "dataset-id")
|
||||
|
||||
assert resp["ok"] is True
|
||||
|
||||
@ -269,7 +250,7 @@ class TestExternalKnowledgeHitTestingApi:
|
||||
class TestBedrockRetrievalApi:
|
||||
def test_bedrock_retrieval(self, app: Flask):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
@ -293,9 +274,9 @@ class TestBedrockRetrievalApi:
|
||||
|
||||
|
||||
class TestExternalApiTemplateListApiAdvanced:
|
||||
def test_post_duplicate_name_error(self, app: Flask, mock_auth, current_user):
|
||||
def test_post_duplicate_name_error(self, app: Flask, current_user: Account):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"name": "duplicate_api", "settings": {"key": "value"}}
|
||||
|
||||
@ -309,11 +290,11 @@ class TestExternalApiTemplateListApiAdvanced:
|
||||
),
|
||||
):
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
def test_get_with_pagination(self, app: Flask, mock_auth, current_user):
|
||||
def test_get_with_pagination(self, app: Flask):
|
||||
api = ExternalApiTemplateListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
templates = [MagicMock(id=f"api-{i}") for i in range(3)]
|
||||
|
||||
@ -333,12 +314,12 @@ class TestExternalApiTemplateListApiAdvanced:
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
def test_create_forbidden(self, app: Flask, mock_auth, current_user):
|
||||
def test_create_forbidden(self, app: Flask, current_user: Account):
|
||||
"""Test creating external dataset without permission"""
|
||||
api = ExternalDatasetCreateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
current_user.is_dataset_editor = False
|
||||
current_user.role = TenantAccountRole.NORMAL
|
||||
|
||||
payload = {
|
||||
"external_knowledge_api_id": "api-1",
|
||||
@ -349,14 +330,14 @@ class TestExternalDatasetCreateApiAdvanced:
|
||||
|
||||
with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, "tenant-1", current_user)
|
||||
|
||||
|
||||
class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, mock_auth, current_user):
|
||||
def test_hit_testing_dataset_not_found(self, app: Flask, current_user: Account):
|
||||
"""Test hit testing on non-existent dataset"""
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"query": "test query",
|
||||
@ -372,11 +353,11 @@ class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api, "ds-1")
|
||||
method(api, current_user, "ds-1")
|
||||
|
||||
def test_hit_testing_with_custom_retrieval_model(self, app: Flask, mock_auth, current_user):
|
||||
def test_hit_testing_with_custom_retrieval_model(self, app: Flask, current_user: Account):
|
||||
api = ExternalKnowledgeHitTestingApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
dataset = MagicMock()
|
||||
payload = {
|
||||
@ -398,15 +379,15 @@ class TestExternalKnowledgeHitTestingApiAdvanced:
|
||||
return_value={"results": []},
|
||||
),
|
||||
):
|
||||
resp = method(api, "ds-1")
|
||||
resp = method(api, current_user, "ds-1")
|
||||
|
||||
assert resp["results"] == []
|
||||
|
||||
|
||||
class TestBedrockRetrievalApiAdvanced:
|
||||
def test_bedrock_retrieval_with_invalid_setting(self, app: Flask, mock_auth, current_user):
|
||||
def test_bedrock_retrieval_with_invalid_setting(self, app: Flask):
|
||||
api = BedrockRetrievalApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"retrieval_setting": {},
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
@ -11,7 +12,7 @@ from controllers.console.workspace.account import (
|
||||
ChangeEmailSendEmailApi,
|
||||
CheckEmailUnique,
|
||||
)
|
||||
from models import Account, AccountStatus
|
||||
from models import Account, AccountStatus, Tenant
|
||||
from services.account_service import AccountService
|
||||
from services.entities.auth_entities import (
|
||||
ChangeEmailNewEmailToken,
|
||||
@ -26,12 +27,16 @@ def app():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["RESTX_MASK_HEADER"] = "X-Fields"
|
||||
app.login_manager = SimpleNamespace(load_user_from_request_context=lambda: None)
|
||||
setattr(app, "login_manager", SimpleNamespace(load_user_from_request_context=lambda: None)) # noqa: B010
|
||||
return app
|
||||
|
||||
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: object | None = None) -> Account:
|
||||
tenant_obj = tenant if tenant is not None else SimpleNamespace(id="tenant-id")
|
||||
def _build_account(email: str, account_id: str = "acc", tenant: Tenant | None = None) -> Account:
|
||||
if tenant is None:
|
||||
tenant_obj = Tenant(name="Tenant")
|
||||
tenant_obj.id = "tenant-id"
|
||||
else:
|
||||
tenant_obj = tenant
|
||||
account = Account(name=account_id, email=email)
|
||||
account.email = email
|
||||
account.id = account_id
|
||||
@ -40,11 +45,6 @@ def _build_account(email: str, account_id: str = "acc", tenant: object | None =
|
||||
return account
|
||||
|
||||
|
||||
def _set_logged_in_user(account: Account):
|
||||
g._login_user = account
|
||||
g._current_tenant = account.current_tenant
|
||||
|
||||
|
||||
def _build_change_email_token(
|
||||
phase: str,
|
||||
*,
|
||||
@ -71,63 +71,44 @@ def _build_change_email_token(
|
||||
|
||||
|
||||
class TestChangeEmailSend:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_old_email_phase_when_request_email_does_not_match_current_user(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidEmailError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("current@example.com", "acc1"), None)
|
||||
current_user = _build_account("current@example.com", "acc1")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "other@example.com", "language": "en-US", "phase": "old_email"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
method = inspect.unwrap(ChangeEmailSendEmailApi().post)
|
||||
with pytest.raises(InvalidEmailError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
method(ChangeEmailSendEmailApi(), current_user)
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
account_id="acc1",
|
||||
@ -141,8 +122,9 @@ class TestChangeEmailSend:
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
api = ChangeEmailSendEmailApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
response = method(api, mock_account)
|
||||
|
||||
assert response == {"result": "success", "data": "token-abc"}
|
||||
mock_send_email.assert_called_once_with(
|
||||
@ -154,34 +136,23 @@ class TestChangeEmailSend:
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_new_email_phase_when_token_phase_is_not_old_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
account_id="acc1",
|
||||
@ -194,37 +165,28 @@ class TestChangeEmailSend:
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailSendEmailApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
method(api, mock_account)
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_new_email_phase_when_token_account_id_does_not_match_current_user(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
account_id="other-account",
|
||||
@ -237,41 +199,32 @@ class TestChangeEmailSend:
|
||||
method="POST",
|
||||
json={"email": "new@example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailSendEmailApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
method(api, mock_account)
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_validate_with_normalized_email(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
@ -286,8 +239,9 @@ class TestChangeEmailValidity:
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
api = ChangeEmailCheckApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
response = method(api, mock_account)
|
||||
|
||||
assert response == {"is_valid": True, "email": "user@example.com", "token": "new-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("user@example.com")
|
||||
@ -303,34 +257,24 @@ class TestChangeEmailValidity:
|
||||
mock_account,
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_upgrade_new_phase_token_to_new_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
current_user = _build_account("old@example.com", "acc")
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
@ -345,8 +289,9 @@ class TestChangeEmailValidity:
|
||||
method="POST",
|
||||
json={"email": "new@example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
api = ChangeEmailCheckApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
response = method(api, current_user)
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-verified-token"}
|
||||
mock_generate_token.assert_called_once_with(
|
||||
@ -356,37 +301,28 @@ class TestChangeEmailValidity:
|
||||
email="new@example.com",
|
||||
old_email="old@example.com",
|
||||
),
|
||||
mock_current_account.return_value[0],
|
||||
current_user,
|
||||
)
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_validity_when_token_is_already_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
current_user = _build_account("old@example.com", "acc")
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
@ -400,41 +336,33 @@ class TestChangeEmailValidity:
|
||||
method="POST",
|
||||
json={"email": "old@example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailCheckApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailCheckApi().post()
|
||||
method(api, current_user)
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_generate_token.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_validity_when_token_account_id_does_not_match_current_user(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
|
||||
current_user = _build_account("old@example.com", "acc")
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
@ -448,42 +376,33 @@ class TestChangeEmailValidity:
|
||||
method="POST",
|
||||
json={"email": "new@example.com", "code": "1234", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailCheckApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailCheckApi().post()
|
||||
method(api, current_user)
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_generate_token.assert_not_called()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_before_update(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
@ -500,46 +419,36 @@ class TestChangeEmailReset:
|
||||
method="POST",
|
||||
json={"new_email": "New@Example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
ChangeEmailResetApi().post()
|
||||
api = ChangeEmailResetApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
method(api, current_user)
|
||||
|
||||
mock_is_freeze.assert_called_once_with("new@example.com")
|
||||
mock_check_unique.assert_called_once_with("new@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_update_account.assert_called_once_with(current_user, email="new@example.com")
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_reset_when_token_phase_is_not_new_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
@ -554,44 +463,35 @@ class TestChangeEmailReset:
|
||||
method="POST",
|
||||
json={"new_email": "attacker@example.com", "token": "token-from-step1"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailResetApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
method(api, current_user)
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_reset_when_token_email_differs_from_payload_new_email(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
"""A verified token for address A must not be replayed to change to address B."""
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
@ -606,43 +506,34 @@ class TestChangeEmailReset:
|
||||
method="POST",
|
||||
json={"new_email": "attacker@example.com", "token": "token-verified"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailResetApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
method(api, current_user)
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_reset_when_token_account_id_does_not_match_current_user(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app: Flask,
|
||||
):
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = _build_change_email_token(
|
||||
@ -657,9 +548,10 @@ class TestChangeEmailReset:
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-verified"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
api = ChangeEmailResetApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
method(api, current_user)
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
@ -755,25 +647,25 @@ class TestAccountServiceGetChangeEmailData:
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
|
||||
def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask):
|
||||
def test_should_normalize_feedback_email(self, mock_update, app: Flask):
|
||||
with app.test_request_context(
|
||||
"/account/delete/feedback",
|
||||
method="POST",
|
||||
json={"email": "User@Example.com", "feedback": "test"},
|
||||
):
|
||||
response = AccountDeleteUpdateFeedbackApi().post()
|
||||
api = AccountDeleteUpdateFeedbackApi()
|
||||
method = inspect.unwrap(api.post)
|
||||
response = method(api)
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_update.assert_called_once_with("User@Example.com", "test")
|
||||
|
||||
|
||||
class TestCheckEmailUnique:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask):
|
||||
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, app: Flask):
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
|
||||
@ -782,7 +674,9 @@ class TestCheckEmailUnique:
|
||||
method="POST",
|
||||
json={"email": "Case@Test.com"},
|
||||
):
|
||||
response = CheckEmailUnique().post()
|
||||
api = CheckEmailUnique()
|
||||
method = inspect.unwrap(api.post)
|
||||
response = method(api)
|
||||
|
||||
assert response == {"result": "success"}
|
||||
mock_is_freeze.assert_called_once_with("case@test.com")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -31,22 +32,29 @@ from controllers.console.workspace.error import (
|
||||
CurrentPasswordIncorrectError,
|
||||
InvalidAccountDeletionCodeError,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus
|
||||
from models.enums import CreatorUserRole
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
def make_account(account_id: str = "u1", *, status: AccountStatus = AccountStatus.ACTIVE) -> Account:
|
||||
account = Account(name="John", email=f"{account_id}@test.com", status=status)
|
||||
account.id = account_id
|
||||
account.avatar = "avatar.png"
|
||||
account.interface_language = "en-US"
|
||||
account.interface_theme = "light"
|
||||
account.timezone = "UTC"
|
||||
account.last_login_ip = "127.0.0.1"
|
||||
return account
|
||||
|
||||
|
||||
class TestAccountInitApi:
|
||||
def test_init_success(self, app: Flask):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="inactive")
|
||||
account = make_account(status=AccountStatus.UNINITIALIZED)
|
||||
payload = {
|
||||
"interface_language": "en-US",
|
||||
"timezone": "UTC",
|
||||
@ -55,50 +63,35 @@ class TestAccountInitApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.commit", return_value=None),
|
||||
patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"),
|
||||
patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock,
|
||||
):
|
||||
scalar_mock.return_value = MagicMock(status="unused")
|
||||
resp = method(api)
|
||||
resp = method(api, account)
|
||||
|
||||
assert resp["result"] == "success"
|
||||
|
||||
def test_init_already_initialized(self, app: Flask):
|
||||
api = AccountInitApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
account = MagicMock(status="active")
|
||||
account = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/init"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
):
|
||||
with app.test_request_context("/account/init"):
|
||||
with pytest.raises(AccountAlreadyInitedError):
|
||||
method(api)
|
||||
method(api, account)
|
||||
|
||||
|
||||
class TestAccountProfileApi:
|
||||
def test_get_profile_success(self, app: Flask):
|
||||
api = AccountProfileApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/account/profile"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
):
|
||||
result = method(api)
|
||||
with app.test_request_context("/account/profile"):
|
||||
result = method(api, user)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
@ -116,24 +109,15 @@ class TestAccountUpdateApis:
|
||||
)
|
||||
def test_update_success(self, app: Flask, api_cls, payload):
|
||||
api = api_cls()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account", return_value=user),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, user)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
@ -143,10 +127,9 @@ class TestAccountAvatarApiGet:
|
||||
|
||||
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask):
|
||||
api = AccountAvatarApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "acc-owner"
|
||||
user = make_account("acc-owner")
|
||||
tenant_id = "tenant-1"
|
||||
file_id = "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
@ -158,27 +141,22 @@ class TestAccountAvatarApiGet:
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/account/avatar?avatar={file_id}"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file),
|
||||
patch(
|
||||
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
|
||||
return_value="https://signed/example",
|
||||
) as sign_mock,
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, user)
|
||||
|
||||
assert result == {"avatar_url": "https://signed/example"}
|
||||
sign_mock.assert_called_once_with(upload_file_id=file_id)
|
||||
|
||||
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask):
|
||||
api = AccountAvatarApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "acc-a"
|
||||
user = make_account("acc-a")
|
||||
tenant_id = "tenant-1"
|
||||
file_id = "550e8400-e29b-41d4-a716-446655440001"
|
||||
|
||||
@ -190,10 +168,6 @@ class TestAccountAvatarApiGet:
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/account/avatar?avatar={file_id}"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file),
|
||||
patch(
|
||||
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
|
||||
@ -201,16 +175,15 @@ class TestAccountAvatarApiGet:
|
||||
) as sign_mock,
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
method(api, tenant_id, user)
|
||||
|
||||
sign_mock.assert_not_called()
|
||||
|
||||
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask):
|
||||
api = AccountAvatarApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "acc-owner"
|
||||
user = make_account("acc-owner")
|
||||
tenant_id = "tenant-1"
|
||||
file_id = "550e8400-e29b-41d4-a716-446655440002"
|
||||
|
||||
@ -222,10 +195,6 @@ class TestAccountAvatarApiGet:
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/account/avatar?avatar={file_id}"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch("controllers.console.workspace.account.db.session.scalar", return_value=upload_file),
|
||||
patch(
|
||||
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
|
||||
@ -233,31 +202,26 @@ class TestAccountAvatarApiGet:
|
||||
) as sign_mock,
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(api)
|
||||
method(api, tenant_id, user)
|
||||
|
||||
sign_mock.assert_not_called()
|
||||
|
||||
def test_get_avatar_https_pass_through_without_signing(self, app: Flask):
|
||||
api = AccountAvatarApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "acc-owner"
|
||||
user = make_account("acc-owner")
|
||||
tenant_id = "tenant-1"
|
||||
external = "https://cdn.example/avatar.png"
|
||||
|
||||
with (
|
||||
app.test_request_context(f"/account/avatar?avatar={external}"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant",
|
||||
return_value=(user, tenant_id),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.file_helpers.get_signed_file_url",
|
||||
return_value="https://signed/should-not-use",
|
||||
) as sign_mock,
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, tenant_id, user)
|
||||
|
||||
assert result == {"avatar_url": external}
|
||||
sign_mock.assert_not_called()
|
||||
@ -266,7 +230,7 @@ class TestAccountAvatarApiGet:
|
||||
class TestAccountPasswordApi:
|
||||
def test_password_success(self, app: Flask):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "old",
|
||||
@ -274,63 +238,51 @@ class TestAccountPasswordApi:
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "u1"
|
||||
user.name = "John"
|
||||
user.email = "john@test.com"
|
||||
user.avatar = "avatar.png"
|
||||
user.interface_language = "en-US"
|
||||
user.interface_theme = "light"
|
||||
user.timezone = "UTC"
|
||||
user.last_login_ip = "127.0.0.1"
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, user)
|
||||
|
||||
assert result["id"] == "u1"
|
||||
|
||||
def test_password_wrong_current(self, app: Flask):
|
||||
api = AccountPasswordApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"password": "bad",
|
||||
"new_password": "new123",
|
||||
"repeat_new_password": "new123",
|
||||
}
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.update_account_password",
|
||||
side_effect=ServicePwdError(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(CurrentPasswordIncorrectError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestAccountIntegrateApi:
|
||||
def test_get_integrates(self, app: Flask):
|
||||
api = AccountIntegrateApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
account = MagicMock(id="acc1")
|
||||
account = make_account("acc1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")),
|
||||
patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock,
|
||||
):
|
||||
scalars_mock.return_value.all.return_value = []
|
||||
result = method(api)
|
||||
result = method(api, account)
|
||||
|
||||
assert "data" in result
|
||||
assert len(result["data"]) == 2
|
||||
@ -339,13 +291,11 @@ class TestAccountIntegrateApi:
|
||||
class TestAccountDeleteApi:
|
||||
def test_delete_verify_success(self, app: Flask):
|
||||
api = AccountDeleteVerifyApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code",
|
||||
return_value=("token", "1234"),
|
||||
@ -355,43 +305,38 @@ class TestAccountDeleteApi:
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, user)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_delete_invalid_code(self, app: Flask):
|
||||
api = AccountDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"token": "t", "code": "x"}
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.account.AccountService.verify_account_deletion_code",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvalidAccountDeletionCodeError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestChangeEmailApis:
|
||||
def test_check_email_code_invalid(self, app: Flask):
|
||||
api = ChangeEmailCheckApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"email": "a@test.com", "code": "x", "token": "t"}
|
||||
user = make_account("acc-1")
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.account.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="acc-1"), "t1"),
|
||||
),
|
||||
patch.object(
|
||||
type(console_ns),
|
||||
"payload",
|
||||
@ -412,13 +357,14 @@ class TestChangeEmailApis:
|
||||
),
|
||||
):
|
||||
with pytest.raises(EmailCodeError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_reset_email_already_used(self, app: Flask):
|
||||
api = ChangeEmailResetApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"new_email": "x@test.com", "token": "t"}
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
@ -432,13 +378,13 @@ class TestChangeEmailApis:
|
||||
patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False),
|
||||
):
|
||||
with pytest.raises(EmailAlreadyInUseError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestCheckEmailUniqueApi:
|
||||
def test_email_unique_success(self, app: Flask):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"email": "ok@test.com"}
|
||||
|
||||
@ -459,7 +405,7 @@ class TestCheckEmailUniqueApi:
|
||||
|
||||
def test_email_in_freeze(self, app: Flask):
|
||||
api = CheckEmailUnique()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"email": "x@test.com"}
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
import inspect
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -18,31 +19,10 @@ from controllers.console.workspace.endpoint import (
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_and_tenant():
|
||||
return MagicMock(id="u1"), "t1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_current_account(user_and_tenant):
|
||||
with patch(
|
||||
"controllers.console.workspace.endpoint.current_account_with_tenant",
|
||||
return_value=user_and_tenant,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointCollectionApi:
|
||||
def test_create_success(self, app: Flask):
|
||||
api = EndpointCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
@ -54,13 +34,13 @@ class TestEndpointCollectionApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_create_permission_denied(self, app: Flask):
|
||||
api = EndpointCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
@ -76,11 +56,11 @@ class TestEndpointCollectionApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
def test_create_validation_error(self, app: Flask):
|
||||
api = EndpointCollectionApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "p1",
|
||||
@ -92,14 +72,13 @@ class TestEndpointCollectionApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestDeprecatedEndpointCreateApi:
|
||||
def test_create_success(self, app: Flask):
|
||||
api = DeprecatedEndpointCreateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"plugin_unique_identifier": "plugin-1",
|
||||
@ -111,42 +90,40 @@ class TestDeprecatedEndpointCreateApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListApi:
|
||||
def test_list_success(self, app: Flask):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert "endpoints" in result
|
||||
assert len(result["endpoints"]) == 1
|
||||
|
||||
def test_list_invalid_query(self, app: Flask):
|
||||
api = EndpointListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=0&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointListForSinglePluginApi:
|
||||
def test_list_for_plugin_success(self, app: Flask):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10&plugin_id=p1"),
|
||||
@ -155,26 +132,25 @@ class TestEndpointListForSinglePluginApi:
|
||||
return_value=[{"id": "e1"}],
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert "endpoints" in result
|
||||
|
||||
def test_list_for_plugin_missing_param(self, app: Flask):
|
||||
api = EndpointListForSinglePluginApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/?page=1&page_size=10"),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointItemApi:
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
@ -183,26 +159,26 @@ class TestEndpointItemApi:
|
||||
return_value=True,
|
||||
) as mock_delete,
|
||||
):
|
||||
result = method(api, "e1")
|
||||
result = method(api, "t1", "u1", "e1")
|
||||
|
||||
assert result["success"] is True
|
||||
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
|
||||
|
||||
def test_delete_service_failure(self, app: Flask):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.delete)
|
||||
method = inspect.unwrap(api.delete)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", method="DELETE"),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api, "e1")
|
||||
result = method(api, "t1", "u1", "e1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
def test_update_success(self, app: Flask):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "new-name",
|
||||
@ -216,7 +192,7 @@ class TestEndpointItemApi:
|
||||
return_value=True,
|
||||
) as mock_update,
|
||||
):
|
||||
result = method(api, "e1")
|
||||
result = method(api, "t1", "u1", "e1")
|
||||
|
||||
assert result["success"] is True
|
||||
mock_update.assert_called_once_with(
|
||||
@ -229,7 +205,7 @@ class TestEndpointItemApi:
|
||||
|
||||
def test_update_validation_error(self, app: Flask):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
payload = {"settings": {}}
|
||||
|
||||
@ -237,11 +213,11 @@ class TestEndpointItemApi:
|
||||
app.test_request_context("/", method="PATCH", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api, "e1")
|
||||
method(api, "t1", "u1", "e1")
|
||||
|
||||
def test_update_service_failure(self, app: Flask):
|
||||
api = EndpointItemApi()
|
||||
method = unwrap(api.patch)
|
||||
method = inspect.unwrap(api.patch)
|
||||
|
||||
payload = {
|
||||
"name": "n",
|
||||
@ -252,16 +228,15 @@ class TestEndpointItemApi:
|
||||
app.test_request_context("/", method="PATCH", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api, "e1")
|
||||
result = method(api, "t1", "u1", "e1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestDeprecatedEndpointDeleteApi:
|
||||
def test_delete_success(self, app: Flask):
|
||||
api = DeprecatedEndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
@ -269,23 +244,23 @@ class TestDeprecatedEndpointDeleteApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_delete_invalid_payload(self, app: Flask):
|
||||
api = DeprecatedEndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
def test_delete_service_failure(self, app: Flask):
|
||||
api = DeprecatedEndpointDeleteApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
@ -293,16 +268,15 @@ class TestDeprecatedEndpointDeleteApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestDeprecatedEndpointUpdateApi:
|
||||
def test_update_success(self, app: Flask):
|
||||
api = DeprecatedEndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
@ -314,13 +288,13 @@ class TestDeprecatedEndpointUpdateApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_update_validation_error(self, app: Flask):
|
||||
api = DeprecatedEndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1", "settings": {}}
|
||||
|
||||
@ -328,11 +302,11 @@ class TestDeprecatedEndpointUpdateApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
def test_update_service_failure(self, app: Flask):
|
||||
api = DeprecatedEndpointUpdateApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {
|
||||
"endpoint_id": "e1",
|
||||
@ -344,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
@ -379,11 +353,10 @@ class TestEndpointRouteMetadata:
|
||||
assert route_map["DeprecatedEndpointUpdateApi"] == ("/workspaces/current/endpoints/update",)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointEnableApi:
|
||||
def test_enable_success(self, app: Flask):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
@ -391,23 +364,23 @@ class TestEndpointEnableApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_enable_invalid_payload(self, app: Flask):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
def test_enable_service_failure(self, app: Flask):
|
||||
api = EndpointEnableApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
@ -415,16 +388,15 @@ class TestEndpointEnableApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is False
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("patch_current_account")
|
||||
class TestEndpointDisableApi:
|
||||
def test_disable_success(self, app: Flask):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"endpoint_id": "e1"}
|
||||
|
||||
@ -432,16 +404,16 @@ class TestEndpointDisableApi:
|
||||
app.test_request_context("/", json=payload),
|
||||
patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1", "u1")
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
def test_disable_invalid_payload(self, app: Flask):
|
||||
api = EndpointDisableApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json={}),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, "t1", "u1")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -28,38 +29,47 @@ from controllers.console.workspace.workspace import (
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.account import TenantStatus
|
||||
from models.account import Account, Tenant, TenantCustomConfigDict, TenantStatus
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
def make_account(account_id: str = "u1") -> Account:
|
||||
account = Account(name="Test User", email=f"{account_id}@example.com")
|
||||
account.id = account_id
|
||||
return account
|
||||
|
||||
|
||||
def make_tenant(
|
||||
tenant_id: str = "t1",
|
||||
*,
|
||||
name: str | None = None,
|
||||
status: TenantStatus = TenantStatus.NORMAL,
|
||||
custom_config: TenantCustomConfigDict | None = None,
|
||||
) -> Tenant:
|
||||
tenant = Tenant(name=name or f"Tenant {tenant_id}", status=status)
|
||||
tenant.id = tenant_id
|
||||
tenant.created_at = naive_utc_now()
|
||||
if custom_config is not None:
|
||||
tenant.custom_config_dict = custom_config
|
||||
return tenant
|
||||
|
||||
|
||||
def make_account_with_tenant(tenant: Tenant) -> Account:
|
||||
account = make_account()
|
||||
account._current_tenant = tenant
|
||||
return account
|
||||
|
||||
|
||||
class TestTenantListApi:
|
||||
def test_get_success_saas_path(self, app: Flask):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant1 = make_tenant("t1", name="Tenant 1")
|
||||
tenant2 = make_tenant("t2", name="Tenant 2")
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
@ -76,7 +86,7 @@ class TestTenantListApi:
|
||||
) as get_plan_bulk_mock,
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "t1", user)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
@ -93,30 +103,18 @@ class TestTenantListApi:
|
||||
(SaaS contract treats enabled as on; display follows subscription.plan).
|
||||
"""
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant1 = make_tenant("t1", name="Tenant 1")
|
||||
tenant2 = make_tenant("t2", name="Tenant 2")
|
||||
|
||||
features_t2 = MagicMock()
|
||||
features_t2.billing.enabled = False
|
||||
features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
@ -133,7 +131,7 @@ class TestTenantListApi:
|
||||
return_value=features_t2,
|
||||
) as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "t1", user)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
@ -148,30 +146,18 @@ class TestTenantListApi:
|
||||
so we simulate the real failure mode by returning empty dict for non-empty input.
|
||||
"""
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant1 = make_tenant("t1", name="Tenant 1")
|
||||
tenant2 = make_tenant("t2", name="Tenant 2")
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
features.billing.subscription.plan = CloudPlan.TEAM
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
@ -189,7 +175,7 @@ class TestTenantListApi:
|
||||
) as get_features_mock,
|
||||
patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "t2", user)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
@ -200,25 +186,17 @@ class TestTenantListApi:
|
||||
|
||||
def test_get_billing_disabled_community_path(self, app: Flask):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant = make_tenant("t1", name="Tenant")
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant],
|
||||
@ -231,7 +209,7 @@ class TestTenantListApi:
|
||||
return_value=features,
|
||||
) as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "t1", user)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
@ -239,26 +217,14 @@ class TestTenantListApi:
|
||||
|
||||
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=naive_utc_now(),
|
||||
)
|
||||
tenant1 = make_tenant("t1", name="Tenant 1")
|
||||
tenant2 = make_tenant("t2", name="Tenant 2")
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
@ -268,7 +234,7 @@ class TestTenantListApi:
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "t2", user)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
@ -279,13 +245,11 @@ class TestTenantListApi:
|
||||
|
||||
def test_get_enterprise_only_with_empty_tenants(self, app: Flask):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[],
|
||||
@ -295,7 +259,7 @@ class TestTenantListApi:
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, None, user)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"] == []
|
||||
@ -305,9 +269,9 @@ class TestTenantListApi:
|
||||
class TestWorkspaceListApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now())
|
||||
tenant = make_tenant("t1", name="T")
|
||||
paginate_result = MagicMock(items=[tenant], has_next=False, total=1)
|
||||
|
||||
with (
|
||||
@ -322,9 +286,9 @@ class TestWorkspaceListApi:
|
||||
|
||||
def test_get_has_next_true(self, app: Flask):
|
||||
api = WorkspaceListApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now())
|
||||
tenant = make_tenant("t1", name="T")
|
||||
paginate_result = MagicMock(items=[tenant], has_next=True, total=10)
|
||||
|
||||
with (
|
||||
@ -340,80 +304,71 @@ class TestWorkspaceListApi:
|
||||
class TestTenantApi:
|
||||
def test_post_active_tenant(self, app: Flask):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
tenant = make_tenant()
|
||||
user = make_account_with_tenant(tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert status == 200
|
||||
assert result["id"] == "t1"
|
||||
|
||||
def test_post_archived_with_switch(self, app: Flask):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
archived = MagicMock(status=TenantStatus.ARCHIVE)
|
||||
new_tenant = MagicMock(status="active")
|
||||
|
||||
user = MagicMock(current_tenant=archived)
|
||||
archived = make_tenant(status=TenantStatus.ARCHIVE)
|
||||
new_tenant = make_tenant("new")
|
||||
user = make_account_with_tenant(archived)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"}
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert result["id"] == "new"
|
||||
|
||||
def test_post_archived_no_tenant(self, app: Flask):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE))
|
||||
user = make_account_with_tenant(make_tenant(status=TenantStatus.ARCHIVE))
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/current"),
|
||||
patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")),
|
||||
patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]),
|
||||
):
|
||||
with pytest.raises(Unauthorized):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_post_info_path(self, app: Flask):
|
||||
api = TenantApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(status="active")
|
||||
user = MagicMock(current_tenant=tenant)
|
||||
tenant = make_tenant()
|
||||
user = make_account_with_tenant(tenant)
|
||||
|
||||
with (
|
||||
app.test_request_context("/info"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(user, "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info",
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.logger.warning") as warn_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
warn_mock.assert_called_once()
|
||||
assert status == 200
|
||||
@ -456,16 +411,14 @@ class TestTenantInfoResponse:
|
||||
class TestSwitchWorkspaceApi:
|
||||
def test_switch_success(self, app: Flask):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "t2"}
|
||||
tenant = MagicMock(id="t2")
|
||||
tenant = make_tenant("t2")
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
|
||||
patch(
|
||||
@ -473,85 +426,73 @@ class TestSwitchWorkspaceApi:
|
||||
),
|
||||
):
|
||||
get_mock.return_value = tenant
|
||||
result = method(api)
|
||||
result = method(api, user)
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_switch_not_linked(self, app: Flask):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "bad"}
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception),
|
||||
):
|
||||
with pytest.raises(AccountNotLinkTenantError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_switch_tenant_not_found(self, app: Flask):
|
||||
api = SwitchWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"tenant_id": "missing"}
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/switch", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.TenantService.switch_tenant"),
|
||||
patch("controllers.console.workspace.workspace.db.session.get") as get_mock,
|
||||
):
|
||||
get_mock.return_value = None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestCustomConfigWorkspaceApi:
|
||||
def test_post_success(self, app: Flask):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={})
|
||||
tenant = make_tenant(custom_config={})
|
||||
|
||||
payload = {"remove_webapp_brand": True}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"}
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_logo_fallback(self, app: Flask):
|
||||
api = CustomConfigWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"})
|
||||
tenant = make_tenant(custom_config={"replace_webapp_logo": "old-logo"})
|
||||
|
||||
payload = {"remove_webapp_brand": False}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/custom-config", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.db.get_or_404",
|
||||
return_value=tenant,
|
||||
@ -562,7 +503,7 @@ class TestCustomConfigWorkspaceApi:
|
||||
return_value={"id": "t1"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo"
|
||||
assert result["result"] == "success"
|
||||
@ -571,54 +512,41 @@ class TestCustomConfigWorkspaceApi:
|
||||
class TestWebappLogoWorkspaceApi:
|
||||
def test_no_file(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/upload", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_too_many_files(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
data = {
|
||||
"file": MagicMock(),
|
||||
"extra": MagicMock(),
|
||||
}
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data=data),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/upload", data=data):
|
||||
with pytest.raises(TooManyFilesError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_invalid_extension(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
file = MagicMock(filename="test.txt")
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context("/upload", data={"file": file}),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/upload", data={"file": file}):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_upload_success(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
@ -627,6 +555,7 @@ class TestWebappLogoWorkspaceApi:
|
||||
)
|
||||
|
||||
upload = MagicMock(id="file1")
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -634,53 +563,46 @@ class TestWebappLogoWorkspaceApi:
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
mock_db.engine = MagicMock()
|
||||
fs.return_value.upload_file.return_value = upload
|
||||
|
||||
result, status = method(api)
|
||||
result, status = method(api, user)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "file1"
|
||||
|
||||
def test_filename_missing(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"data"),
|
||||
filename="",
|
||||
content_type="image/png",
|
||||
)
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
with app.test_request_context(
|
||||
"/upload",
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_file_too_large(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -688,10 +610,6 @@ class TestWebappLogoWorkspaceApi:
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
@ -699,17 +617,18 @@ class TestWebappLogoWorkspaceApi:
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big")
|
||||
|
||||
with pytest.raises(FileTooLargeError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
def test_service_unsupported_file(self, app: Flask):
|
||||
api = WebappLogoWorkspaceApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
file = FileStorage(
|
||||
stream=BytesIO(b"x"),
|
||||
filename="logo.png",
|
||||
content_type="image/png",
|
||||
)
|
||||
user = make_account()
|
||||
|
||||
with (
|
||||
app.test_request_context(
|
||||
@ -717,10 +636,6 @@ class TestWebappLogoWorkspaceApi:
|
||||
data={"file": file},
|
||||
content_type="multipart/form-data",
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "t1"),
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.FileService") as fs,
|
||||
patch("controllers.console.workspace.workspace.db") as mock_db,
|
||||
):
|
||||
@ -728,23 +643,20 @@ class TestWebappLogoWorkspaceApi:
|
||||
fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError()
|
||||
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
method(api)
|
||||
method(api, user)
|
||||
|
||||
|
||||
class TestWorkspaceInfoApi:
|
||||
def test_post_success(self, app: Flask):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
tenant = MagicMock()
|
||||
tenant = make_tenant()
|
||||
|
||||
payload = {"name": "New Name"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant),
|
||||
patch("controllers.console.workspace.workspace.db.session.commit"),
|
||||
patch(
|
||||
@ -752,31 +664,27 @@ class TestWorkspaceInfoApi:
|
||||
return_value={"name": "New Name"},
|
||||
),
|
||||
):
|
||||
result = method(api)
|
||||
result = method(api, "t1")
|
||||
|
||||
assert result["result"] == "success"
|
||||
|
||||
def test_no_current_tenant(self, app: Flask):
|
||||
api = WorkspaceInfoApi()
|
||||
method = unwrap(api.post)
|
||||
method = inspect.unwrap(api.post)
|
||||
|
||||
payload = {"name": "X"}
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces/info", json=payload),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, None)
|
||||
|
||||
|
||||
class TestWorkspacePermissionApi:
|
||||
def test_get_success(self, app: Flask):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
permission = MagicMock(
|
||||
workspace_id="t1",
|
||||
@ -786,29 +694,20 @@ class TestWorkspacePermissionApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission",
|
||||
return_value=permission,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "t1")
|
||||
|
||||
assert status == 200
|
||||
assert result["workspace_id"] == "t1"
|
||||
|
||||
def test_no_current_tenant(self, app: Flask):
|
||||
api = WorkspacePermissionApi()
|
||||
method = unwrap(api.get)
|
||||
method = inspect.unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/permission"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/permission"):
|
||||
with pytest.raises(ValueError):
|
||||
method(api)
|
||||
method(api, None)
|
||||
|
||||
@ -82,11 +82,13 @@ This test suite follows a comprehensive testing strategy that covers:
|
||||
================================================================================
|
||||
"""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask import Flask, g
|
||||
from flask.testing import FlaskClient
|
||||
from flask_restx import Api
|
||||
|
||||
@ -95,6 +97,7 @@ from controllers.console.datasets.external import (
|
||||
ExternalApiTemplateListApi,
|
||||
)
|
||||
from controllers.console.datasets.hit_testing import HitTestingApi
|
||||
from models.account import Account, AccountStatus, TenantAccountRole
|
||||
from models.dataset import Dataset, DatasetPermissionEnum
|
||||
|
||||
# ============================================================================
|
||||
@ -817,15 +820,32 @@ class TestExternalDatasetApi:
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user(self):
|
||||
"""Mock current user and tenant context."""
|
||||
with patch("controllers.console.datasets.external.current_account_with_tenant") as mock_get_user:
|
||||
mock_user = ControllerApiTestDataFactory.create_user_mock(is_dataset_editor=True)
|
||||
def mock_current_account_context(self, app: Flask) -> Iterator[Mock]:
|
||||
"""Provide the wrapper auth context required by HTTP-client controller tests."""
|
||||
mock_user = Account(name="Test User", email="user-123@example.com")
|
||||
mock_user.id = "user-123"
|
||||
mock_user.status = AccountStatus.ACTIVE
|
||||
mock_user.role = TenantAccountRole.EDITOR
|
||||
|
||||
def load_user_from_request_context() -> None:
|
||||
g._login_user = mock_user
|
||||
|
||||
setattr( # noqa: B010
|
||||
app,
|
||||
"login_manager",
|
||||
SimpleNamespace(load_user_from_request_context=load_user_from_request_context),
|
||||
)
|
||||
|
||||
with (
|
||||
patch("controllers.console.wraps.current_account_with_tenant") as mock_get_user,
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"),
|
||||
patch("libs.login.check_csrf_token", return_value=None),
|
||||
):
|
||||
mock_tenant_id = "tenant-123"
|
||||
mock_get_user.return_value = (mock_user, mock_tenant_id)
|
||||
yield mock_get_user
|
||||
|
||||
def test_get_external_knowledge_apis_success(self, client_list, mock_current_user):
|
||||
def test_get_external_knowledge_apis_success(self, client_list: FlaskClient, mock_current_account_context: Mock):
|
||||
"""
|
||||
Test successful retrieval of external knowledge API list.
|
||||
|
||||
@ -839,7 +859,11 @@ class TestExternalDatasetApi:
|
||||
- Status code is 200
|
||||
"""
|
||||
# Arrange
|
||||
apis = [{"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"} for i in range(3)]
|
||||
apis = []
|
||||
for i in range(3):
|
||||
api_item = Mock()
|
||||
api_item.to_dict.return_value = {"id": f"api-{i}", "name": f"API {i}", "endpoint": f"https://api{i}.com"}
|
||||
apis.append(api_item)
|
||||
|
||||
with patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis"
|
||||
@ -847,6 +871,7 @@ class TestExternalDatasetApi:
|
||||
mock_get_apis.return_value = (apis, 3)
|
||||
|
||||
# Act
|
||||
# TODO: this should be made integrated tests...
|
||||
response = client_list.get("/datasets/external-knowledge-api?page=1&limit=20")
|
||||
|
||||
# Assert
|
||||
|
||||
Loading…
Reference in New Issue
Block a user