From 2a46a7d91d957c2a08756f551689a381c0fab720 Mon Sep 17 00:00:00 2001 From: chariri Date: Thu, 11 Jun 2026 10:30:31 +0900 Subject: [PATCH] refactor(api): migrate remaining console APIs to use injected user/tenant (#37288) --- .../console/datasets/hit_testing.py | 14 +- .../console/datasets/hit_testing_base.py | 17 +- api/controllers/console/datasets/metadata.py | 15 +- .../datasets/rag_pipeline/rag_pipeline.py | 29 +- .../rag_pipeline/rag_pipeline_workflow.py | 13 +- api/controllers/console/explore/completion.py | 13 +- .../console/explore/conversation.py | 29 +- api/controllers/console/explore/trial.py | 35 +- api/controllers/console/feature.py | 20 +- .../console/workspace/trigger_providers.py | 150 +++----- api/libs/login.py | 48 +++ api/services/metadata_service.py | 34 +- .../built_in/built_in_retrieval.py | 3 +- .../customized/customized_retrieval.py | 6 +- .../database/database_retrieval.py | 3 +- .../pipeline_template_base.py | 2 +- .../remote/remote_retrieval.py | 3 +- api/services/rag_pipeline/rag_pipeline.py | 44 ++- .../rag_pipeline/test_rag_pipeline.py | 24 +- .../test_rag_pipeline_workflow.py | 125 ++++-- .../console/explore/test_conversation.py | 133 ++++--- .../workspace/test_trigger_providers.py | 158 ++++---- .../test_rag_pipeline_service_db.py | 110 +++--- .../services/test_metadata_partial_update.py | 44 ++- .../services/test_metadata_service.py | 357 ++++++++---------- .../rag_pipeline/test_rag_pipeline.py | 100 +++-- .../console/datasets/test_hit_testing.py | 36 +- .../console/datasets/test_hit_testing_base.py | 76 ++-- .../console/datasets/test_metadata.py | 19 +- .../console/explore/test_completion.py | 76 ++-- .../controllers/console/explore/test_trial.py | 352 +++++++---------- .../controllers/console/test_feature.py | 77 ++-- .../controllers/console/test_wraps.py | 3 + .../service_api/dataset/test_hit_testing.py | 52 ++- api/tests/unit_tests/libs/test_login.py | 98 ++++- .../test_customized_retrieval.py | 6 +- .../rag_pipeline/test_rag_pipeline_service.py | 254 ++++++++----- .../services/test_metadata_bug_complete.py | 80 ++-- .../services/test_metadata_nullable_bug.py | 55 ++- 39 files changed, 1448 insertions(+), 1265 deletions(-) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 37640138eb3..c08ed2fe9f0 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -8,6 +8,7 @@ from controllers.common.schema import register_response_schema_models, register_ from fields.hit_testing_fields import HitTestingResponse from libs.helper import dump_response from libs.login import login_required +from models import Account from .. import console_ns from ..datasets.hit_testing_base import DatasetsHitTestingBase, HitTestingPayload @@ -15,6 +16,8 @@ from ..wraps import ( account_initialization_required, cloud_edition_billing_rate_limit_check, setup_required, + with_current_tenant_id, + with_current_user, ) register_schema_models(console_ns, HitTestingPayload) @@ -38,11 +41,16 @@ class HitTestingApi(Resource, DatasetsHitTestingBase): @login_required @account_initialization_required @cloud_edition_billing_rate_limit_check("knowledge") - def post(self, dataset_id: UUID) -> dict[str, object]: + @with_current_tenant_id + @with_current_user + def post(self, current_user: Account, current_tenant_id: str, dataset_id: UUID) -> dict[str, object]: dataset_id_str = str(dataset_id) - dataset = self.get_and_validate_dataset(dataset_id_str) + dataset = self.get_and_validate_dataset(dataset_id_str, current_user, current_tenant_id) args = self.parse_args(console_ns.payload) self.hit_testing_args_check(args) - return dump_response(HitTestingResponse, self.perform_hit_testing(dataset, args)) + return dump_response( + HitTestingResponse, + self.perform_hit_testing(dataset, args, current_user, current_tenant_id), + ) diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 4be91e0e54d..6141d2d1d58 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -19,7 +19,7 @@ from core.errors.error import ( QuotaExceededError, ) from graphon.model_runtime.errors.invoke import InvokeError -from libs.login import current_user +from libs.login import resolve_account_fallback from models.account import Account from models.dataset import Dataset from services.dataset_service import DatasetService @@ -71,8 +71,10 @@ class DatasetsHitTestingBase: return normalized_records @staticmethod - def get_and_validate_dataset(dataset_id: str) -> Dataset: - assert isinstance(current_user, Account) + def get_and_validate_dataset( + dataset_id: str, current_user: Account | None = None, current_tenant_id: str | None = None + ) -> Dataset: + current_user, _ = resolve_account_fallback(current_user, current_tenant_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") @@ -95,9 +97,14 @@ class DatasetsHitTestingBase: return hit_testing_payload.model_dump(exclude_none=True) @staticmethod - def perform_hit_testing(dataset: Dataset, args: dict[str, Any]) -> dict[str, Any]: - assert isinstance(current_user, Account) + def perform_hit_testing( + dataset: Dataset, + args: dict[str, Any], + current_user: Account | None = None, + current_tenant_id: str | None = None, + ) -> dict[str, Any]: try: + current_user, _ = resolve_account_fallback(current_user, current_tenant_id) response = HitTestingService.retrieve( dataset=dataset, query=cast(str, args.get("query")), diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 5b53c40ae97..ec4c5bedb61 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -11,6 +11,7 @@ from controllers.console.wraps import ( account_initialization_required, enterprise_license_required, setup_required, + with_current_tenant_id, with_current_user, ) from fields.dataset_fields import ( @@ -50,7 +51,8 @@ class DatasetMetadataCreateApi(Resource): @console_ns.response(201, "Metadata created successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataArgs.__name__]) @with_current_user - def post(self, current_user: Account, dataset_id: UUID): + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): metadata_args = MetadataArgs.model_validate(console_ns.payload or {}) dataset_id_str = str(dataset_id) @@ -59,7 +61,7 @@ class DatasetMetadataCreateApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.create_metadata(dataset_id_str, metadata_args) + metadata = MetadataService.create_metadata(dataset_id_str, metadata_args, current_user, current_tenant_id) return dump_response(DatasetMetadataResponse, metadata), 201 @setup_required @@ -87,7 +89,8 @@ class DatasetMetadataApi(Resource): @console_ns.response(200, "Metadata updated successfully", console_ns.models[DatasetMetadataResponse.__name__]) @console_ns.expect(console_ns.models[MetadataUpdatePayload.__name__]) @with_current_user - def patch(self, current_user: Account, dataset_id: UUID, metadata_id: UUID): + @with_current_tenant_id + def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, metadata_id: UUID): payload = MetadataUpdatePayload.model_validate(console_ns.payload or {}) name = payload.name @@ -98,7 +101,9 @@ class DatasetMetadataApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, name) + metadata = MetadataService.update_metadata_name( + dataset_id_str, metadata_id_str, name, current_user, current_tenant_id + ) return dump_response(DatasetMetadataResponse, metadata), 200 @setup_required @@ -181,7 +186,7 @@ class DocumentMetadataEditApi(Resource): metadata_args = MetadataOperationData.model_validate(console_ns.payload or {}) - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(dataset, metadata_args, current_user) # Frontend callers only await success and invalidate caches; no response body is consumed. return "", 204 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index d21800d53c1..ca41573cb85 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -21,11 +21,14 @@ from controllers.console.wraps import ( enterprise_license_required, knowledge_pipeline_publish_enabled, setup_required, + with_current_tenant_id, + with_current_user, ) from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import dump_response from libs.login import login_required +from models.account import Account from models.dataset import PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -96,10 +99,11 @@ class PipelineTemplateListApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def get(self) -> JsonResponseWithStatus: + @with_current_tenant_id + def get(self, current_tenant_id: str) -> JsonResponseWithStatus: query = PipelineTemplateListQuery.model_validate(request.args.to_dict(flat=True)) # get pipeline templates - pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language) + pipeline_templates = RagPipelineService.get_pipeline_templates(query.type, query.language, current_tenant_id) return dump_response(PipelineTemplateListResponse, pipeline_templates), 200 @@ -128,10 +132,14 @@ class CustomizedPipelineTemplateApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def patch(self, template_id: str) -> tuple[str, int]: + @with_current_user + @with_current_tenant_id + def patch(self, current_tenant_id: str, current_user: Account, template_id: str) -> tuple[str, int]: payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {}) pipeline_template_info = PipelineTemplateInfoEntity.model_validate(payload.model_dump()) - RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) + RagPipelineService.update_customized_pipeline_template( + template_id, pipeline_template_info, current_user, current_tenant_id + ) return "", 204 @console_ns.response(204, "Pipeline template deleted") @@ -139,8 +147,9 @@ class CustomizedPipelineTemplateApi(Resource): @login_required @account_initialization_required @enterprise_license_required - def delete(self, template_id: str) -> tuple[str, int]: - RagPipelineService.delete_customized_pipeline_template(template_id) + @with_current_tenant_id + def delete(self, current_tenant_id: str, template_id: str) -> tuple[str, int]: + RagPipelineService.delete_customized_pipeline_template(template_id, current_tenant_id) return "", 204 @setup_required @@ -168,8 +177,12 @@ class PublishCustomizedPipelineTemplateApi(Resource): @account_initialization_required @enterprise_license_required @knowledge_pipeline_publish_enabled - def post(self, pipeline_id: str) -> tuple[str, int]: + @with_current_user + @with_current_tenant_id + def post(self, current_tenant_id: str, current_user: Account, pipeline_id: str) -> tuple[str, int]: payload = CustomizedPipelineTemplatePayload.model_validate(console_ns.payload or {}) rag_pipeline_service = RagPipelineService() - rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, payload.model_dump()) + rag_pipeline_service.publish_customized_pipeline_template( + pipeline_id, payload.model_dump(), current_user, current_tenant_id + ) return "", 204 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index ebc3b92dd63..53e0d0c2931 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -48,7 +48,7 @@ from fields.workflow_run_fields import ( from graphon.model_runtime.utils.encoders import jsonable_encoder from libs import helper from libs.helper import TimestampField, UUIDStrOrEmpty, dump_response -from libs.login import current_user, login_required +from libs.login import login_required from models import Account from models.dataset import Pipeline from models.model import EndUser @@ -835,7 +835,7 @@ class RagPipelineWorkflowRunListApi(Resource): } ) args = { - "last_id": str(query.last_id) if query.last_id else None, + "last_id": query.last_id or None, "limit": query.limit, } @@ -881,7 +881,8 @@ class RagPipelineWorkflowRunNodeExecutionListApi(Resource): @login_required @account_initialization_required @get_rag_pipeline - def get(self, pipeline: Pipeline, run_id: UUID): + @with_current_user + def get(self, current_user: Account, pipeline: Pipeline, run_id: UUID): """ Get workflow run node execution list """ @@ -988,9 +989,11 @@ class RagPipelineRecommendedPluginApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, current_user: Account): query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict()) rag_pipeline_service = RagPipelineService() - recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type) + recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type, current_user, current_tenant_id) return recommended_plugins diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index d1ae6526c68..1db177a29dd 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -18,7 +18,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource -from controllers.console.wraps import with_current_user_id +from controllers.console.wraps import with_current_user, with_current_user_id from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ( @@ -30,7 +30,6 @@ from extensions.ext_database import db from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.datetime_utils import naive_utc_now -from libs.login import current_user from models import Account from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService @@ -84,7 +83,8 @@ register_response_schema_models(console_ns, SimpleResultResponse) ) class CompletionApi(InstalledAppResource): @console_ns.expect(console_ns.models[CompletionMessageExplorePayload.__name__]) - def post(self, installed_app: InstalledApp): + @with_current_user + def post(self, current_user: Account, installed_app: InstalledApp): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -101,8 +101,6 @@ class CompletionApi(InstalledAppResource): db.session.commit() try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) @@ -160,7 +158,8 @@ class CompletionStopApi(InstalledAppResource): ) class ChatApi(InstalledAppResource): @console_ns.expect(console_ns.models[ChatMessagePayload.__name__]) - def post(self, installed_app: InstalledApp): + @with_current_user + def post(self, current_user: Account, installed_app: InstalledApp): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -177,8 +176,6 @@ class ChatApi(InstalledAppResource): db.session.commit() try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 68e18a0207b..9cebba496b5 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -11,6 +11,7 @@ from controllers.common.schema import register_response_schema_models, register_ from controllers.console.app.error import AppUnavailableError from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource +from controllers.console.wraps import with_current_user from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( @@ -19,7 +20,6 @@ from fields.conversation_fields import ( SimpleConversation, ) from libs.helper import UUIDStrOrEmpty -from libs.login import current_user from models import Account from models.model import AppMode, InstalledApp from services.conversation_service import ConversationService @@ -45,7 +45,8 @@ register_response_schema_models(console_ns, ResultResponse) ) class ConversationListApi(InstalledAppResource): @console_ns.expect(console_ns.models[ConversationListQuery.__name__]) - def get(self, installed_app: InstalledApp): + @with_current_user + def get(self, current_user: Account, installed_app: InstalledApp): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -66,14 +67,12 @@ class ConversationListApi(InstalledAppResource): args = ConversationListQuery.model_validate(raw_args) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") with sessionmaker(db.engine).begin() as session: pagination = WebConversationService.pagination_by_last_id( session=session, app_model=app_model, user=current_user, - last_id=str(args.last_id) if args.last_id else None, + last_id=args.last_id or None, limit=args.limit, invoke_from=InvokeFrom.EXPLORE, pinned=args.pinned, @@ -95,7 +94,8 @@ class ConversationListApi(InstalledAppResource): ) class ConversationApi(InstalledAppResource): @console_ns.response(204, "Conversation deleted successfully") - def delete(self, installed_app: InstalledApp, c_id: UUID): + @with_current_user + def delete(self, current_user: Account, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -105,8 +105,6 @@ class ConversationApi(InstalledAppResource): conversation_id = str(c_id) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") ConversationService.delete(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -120,7 +118,8 @@ class ConversationApi(InstalledAppResource): ) class ConversationRenameApi(InstalledAppResource): @console_ns.expect(console_ns.models[ConversationRenamePayload.__name__]) - def post(self, installed_app: InstalledApp, c_id: UUID): + @with_current_user + def post(self, current_user: Account, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -133,8 +132,6 @@ class ConversationRenameApi(InstalledAppResource): payload = ConversationRenamePayload.model_validate(console_ns.payload or {}) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") conversation = ConversationService.rename( app_model, conversation_id, current_user, payload.name, payload.auto_generate ) @@ -153,7 +150,8 @@ class ConversationRenameApi(InstalledAppResource): ) class ConversationPinApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def patch(self, installed_app: InstalledApp, c_id: UUID): + @with_current_user + def patch(self, current_user: Account, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -164,8 +162,6 @@ class ConversationPinApi(InstalledAppResource): conversation_id = str(c_id) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") WebConversationService.pin(app_model, conversation_id, current_user) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -179,7 +175,8 @@ class ConversationPinApi(InstalledAppResource): ) class ConversationUnPinApi(InstalledAppResource): @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def patch(self, installed_app: InstalledApp, c_id: UUID): + @with_current_user + def patch(self, current_user: Account, installed_app: InstalledApp, c_id: UUID): app_model = installed_app.app if app_model is None: raise AppUnavailableError() @@ -188,8 +185,6 @@ class ConversationUnPinApi(InstalledAppResource): raise NotChatAppError() conversation_id = str(c_id) - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") WebConversationService.unpin(app_model, conversation_id, current_user) return ResultResponse(result="success").model_dump(mode="json") diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 26b48ec599a..2e7796574f1 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -33,6 +33,7 @@ from controllers.console.explore.error import ( NotWorkflowAppError, ) from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable +from controllers.console.wraps import with_current_user from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from core.app.apps.base_app_queue_manager import AppQueueManager @@ -63,7 +64,6 @@ from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value -from libs.login import current_user from models import Account from models.account import TenantStatus from models.model import AppMode, Site @@ -155,7 +155,8 @@ register_schema_models(console_ns, WorkflowRunRequest, ChatRequest, TextToSpeech class TrialAppWorkflowRunApi(TrialAppResource): @trial_feature_enable @console_ns.expect(console_ns.models[WorkflowRunRequest.__name__]) - def post(self, trial_app): + @with_current_user + def post(self, current_user: Account, trial_app): """ Run workflow """ @@ -168,7 +169,6 @@ class TrialAppWorkflowRunApi(TrialAppResource): request_data = WorkflowRunRequest.model_validate(console_ns.payload) args = request_data.model_dump() - assert current_user is not None try: app_id = app_model.id user_id = current_user.id @@ -206,7 +206,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource): app_mode = AppMode.value_of(app_model.mode) if app_mode != AppMode.WORKFLOW: raise NotWorkflowAppError() - assert current_user is not None # Stop using both mechanisms for backward compatibility # Legacy stop flag mechanism (without user check) @@ -221,7 +220,8 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource): class TrialChatApi(TrialAppResource): @console_ns.expect(console_ns.models[ChatRequest.__name__]) @trial_feature_enable - def post(self, trial_app): + @with_current_user + def post(self, current_user: Account, trial_app): app_model = trial_app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -239,9 +239,6 @@ class TrialChatApi(TrialAppResource): args["auto_generate_name"] = False try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - # Get IDs before they might be detached from session app_id = app_model.id user_id = current_user.id @@ -276,7 +273,8 @@ class TrialChatApi(TrialAppResource): class TrialMessageSuggestedQuestionApi(TrialAppResource): - def get(self, trial_app, message_id): + @with_current_user + def get(self, current_user: Account, trial_app, message_id): app_model = trial_app app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: @@ -285,8 +283,6 @@ class TrialMessageSuggestedQuestionApi(TrialAppResource): message_id = str(message_id) try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) @@ -313,15 +309,13 @@ class TrialMessageSuggestedQuestionApi(TrialAppResource): class TrialChatAudioApi(TrialAppResource): @trial_feature_enable - def post(self, trial_app): + @with_current_user + def post(self, current_user: Account, trial_app): app_model = trial_app file = request.files["file"] try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - # Get IDs before they might be detached from session app_id = app_model.id user_id = current_user.id @@ -358,7 +352,8 @@ class TrialChatAudioApi(TrialAppResource): class TrialChatTextApi(TrialAppResource): @console_ns.expect(console_ns.models[TextToSpeechRequest.__name__]) @trial_feature_enable - def post(self, trial_app): + @with_current_user + def post(self, current_user: Account, trial_app): app_model = trial_app try: request_data = TextToSpeechRequest.model_validate(console_ns.payload) @@ -366,8 +361,6 @@ class TrialChatTextApi(TrialAppResource): message_id = request_data.message_id text = request_data.text voice = request_data.voice - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") # Get IDs before they might be detached from session app_id = app_model.id @@ -405,7 +398,8 @@ class TrialChatTextApi(TrialAppResource): class TrialCompletionApi(TrialAppResource): @console_ns.expect(console_ns.models[CompletionRequest.__name__]) @trial_feature_enable - def post(self, trial_app): + @with_current_user + def post(self, current_user: Account, trial_app): app_model = trial_app if app_model.mode != "completion": raise NotCompletionAppError() @@ -417,9 +411,6 @@ class TrialCompletionApi(TrialAppResource): args["auto_generate_name"] = False try: - if not isinstance(current_user, Account): - raise ValueError("current_user must be an Account instance") - # Get IDs before they might be detached from session app_id = app_model.id user_id = current_user.id diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index b221db697ba..3b1b414150a 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,10 +1,9 @@ from flask_restx import Resource -from werkzeug.exceptions import Unauthorized from controllers.common.schema import register_response_schema_models from fields.base import ResponseModel from libs.helper import dump_response -from libs.login import current_user, login_required +from libs.login import current_account_with_tenant_optional, login_required from services.feature_service import ( FeatureModel, FeatureService, @@ -13,7 +12,12 @@ from services.feature_service import ( ) from . import console_ns -from .wraps import account_initialization_required, cloud_utm_record, setup_required, with_current_tenant_id +from .wraps import ( + account_initialization_required, + cloud_utm_record, + setup_required, + with_current_tenant_id, +) class TrialModelsResponse(ResponseModel): @@ -133,12 +137,6 @@ class SystemFeatureApi(Resource): Only non-sensitive configuration data should be returned by this endpoint. """ - # NOTE(QuantumGhost): ideally we should access `current_user.is_authenticated` - # without a try-catch. However, due to the implementation of user loader (the `load_user_from_request` - # in api/extensions/ext_login.py), accessing `current_user.is_authenticated` will - # raise `Unauthorized` exception if authentication token is not provided. - try: - is_authenticated = current_user.is_authenticated - except Unauthorized: - is_authenticated = False + current_user, _ = current_account_with_tenant_optional() + is_authenticated = current_user is not None return FeatureService.get_system_features(is_authenticated=is_authenticated).model_dump() diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 3805d0ff372..d862ba4ff45 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -17,7 +17,7 @@ from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager from extensions.ext_database import db from graphon.model_runtime.utils.encoders import jsonable_encoder -from libs.login import current_user, login_required +from libs.login import login_required from models.account import Account from models.provider_ids import TriggerProviderID from services.plugin.oauth_service import OAuthProxyService @@ -31,6 +31,8 @@ from ..wraps import ( edit_permission_required, is_admin_or_owner_required, setup_required, + with_current_tenant_id, + with_current_user, ) logger = logging.getLogger(__name__) @@ -77,12 +79,9 @@ class TriggerProviderIconApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - - return TriggerManager.get_trigger_plugin_icon(tenant_id=user.current_tenant_id, provider_id=provider) + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): + return TriggerManager.get_trigger_plugin_icon(tenant_id=tenant_id, provider_id=provider) @console_ns.route("/workspaces/current/triggers") @@ -90,12 +89,10 @@ class TriggerProviderListApi(Resource): @setup_required @login_required @account_initialization_required - def get(self): + @with_current_tenant_id + def get(self, tenant_id: str): """List all trigger providers for the current tenant""" - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - return jsonable_encoder(TriggerProviderService.list_trigger_providers(user.current_tenant_id)) + return jsonable_encoder(TriggerProviderService.list_trigger_providers(tenant_id)) @console_ns.route("/workspaces/current/trigger-provider//info") @@ -103,14 +100,10 @@ class TriggerProviderInfoApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): """Get info for a trigger provider""" - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - return jsonable_encoder( - TriggerProviderService.get_trigger_provider(user.current_tenant_id, TriggerProviderID(provider)) - ) + return jsonable_encoder(TriggerProviderService.get_trigger_provider(tenant_id, TriggerProviderID(provider))) @console_ns.route("/workspaces/current/trigger-provider//subscriptions/list") @@ -119,16 +112,14 @@ class TriggerSubscriptionListApi(Resource): @login_required @edit_permission_required @account_initialization_required - def get(self, provider: str): + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account, provider: str): """List all trigger subscriptions for the current tenant's provider""" - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - try: return jsonable_encoder( TriggerProviderService.list_trigger_provider_subscriptions( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=TriggerProviderID(provider), user=user, ) @@ -149,17 +140,16 @@ class TriggerSubscriptionBuilderCreateApi(Resource): @login_required @edit_permission_required @account_initialization_required - def post(self, provider: str): + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account, provider: str): """Add a new subscription instance for a trigger provider""" - user = current_user - assert user.current_tenant_id is not None - payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {}) try: credential_type = CredentialType.of(payload.credential_type) subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, provider_id=TriggerProviderID(provider), credential_type=credential_type, @@ -194,17 +184,16 @@ class TriggerSubscriptionBuilderVerifyApi(Resource): @login_required @edit_permission_required @account_initialization_required - def post(self, provider: str, subscription_builder_id: str): + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str): """Verify and update a subscription instance for a trigger provider""" - user = current_user - assert user.current_tenant_id is not None - payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_verify to prevent race conditions return TriggerSubscriptionBuilderService.update_and_verify_builder( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, @@ -226,17 +215,14 @@ class TriggerSubscriptionBuilderUpdateApi(Resource): @login_required @edit_permission_required @account_initialization_required - def post(self, provider: str, subscription_builder_id: str): + @with_current_tenant_id + def post(self, tenant_id: str, provider: str, subscription_builder_id: str): """Update a subscription instance for a trigger provider""" - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: return jsonable_encoder( TriggerSubscriptionBuilderService.update_trigger_subscription_builder( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, subscription_builder_updater=SubscriptionBuilderUpdater( @@ -262,10 +248,6 @@ class TriggerSubscriptionBuilderLogsApi(Resource): @account_initialization_required def get(self, provider: str, subscription_builder_id: str): """Get the request logs for a subscription instance for a trigger provider""" - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - try: logs = TriggerSubscriptionBuilderService.list_logs(subscription_builder_id) return jsonable_encoder({"logs": [log.model_dump(mode="json") for log in logs]}) @@ -283,15 +265,15 @@ class TriggerSubscriptionBuilderBuildApi(Resource): @login_required @edit_permission_required @account_initialization_required - def post(self, provider: str, subscription_builder_id: str): + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account, provider: str, subscription_builder_id: str): """Build a subscription instance for a trigger provider""" - user = current_user - assert user.current_tenant_id is not None payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) try: # Use atomic update_and_build to prevent race conditions TriggerSubscriptionBuilderService.update_and_build_builder( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, provider_id=TriggerProviderID(provider), subscription_builder_id=subscription_builder_id, @@ -316,15 +298,13 @@ class TriggerSubscriptionUpdateApi(Resource): @login_required @edit_permission_required @account_initialization_required - def post(self, subscription_id: str): + @with_current_tenant_id + def post(self, tenant_id: str, subscription_id: str): """Update a subscription instance""" - user = current_user - assert user.current_tenant_id is not None - request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {}) subscription = TriggerProviderService.get_subscription_by_id( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, subscription_id=subscription_id, ) if not subscription: @@ -341,7 +321,7 @@ class TriggerSubscriptionUpdateApi(Resource): manually_created = subscription.credential_type == CredentialType.UNAUTHORIZED if rename or manually_created: TriggerProviderService.update_trigger_subscription( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, subscription_id=subscription_id, name=request.name, properties=request.properties, @@ -351,7 +331,7 @@ class TriggerSubscriptionUpdateApi(Resource): # For the rest cases(API_KEY, OAUTH2) # we need to call third party provider(e.g. GitHub) to rebuild the subscription TriggerProviderService.rebuild_trigger_subscription( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, name=request.name, provider_id=provider_id, subscription_id=subscription_id, @@ -375,23 +355,21 @@ class TriggerSubscriptionDeleteApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, subscription_id: str): + @with_current_tenant_id + def post(self, tenant_id: str, subscription_id: str): """Delete a subscription instance""" - user = current_user - assert user.current_tenant_id is not None - try: with sessionmaker(db.engine).begin() as session: # Delete trigger provider subscription TriggerProviderService.delete_trigger_provider( session=session, - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, subscription_id=subscription_id, ) # Delete plugin triggers TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription( session=session, - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, subscription_id=subscription_id, ) return {"result": "success"} @@ -407,17 +385,14 @@ class TriggerOAuthAuthorizeApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider: str): + @with_current_user + @with_current_tenant_id + def get(self, tenant_id: str, user: Account, provider: str): """Initiate OAuth authorization flow for a trigger provider""" - user = current_user - assert isinstance(user, Account) - assert user.current_tenant_id is not None - try: provider_id = TriggerProviderID(provider) plugin_id = provider_id.plugin_id provider_name = provider_id.provider_name - tenant_id = user.current_tenant_id # Get OAuth client configuration oauth_client_params = TriggerProviderService.get_oauth_client( @@ -557,30 +532,28 @@ class TriggerOAuthClientManageApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def get(self, provider: str): + @with_current_tenant_id + def get(self, tenant_id: str, provider: str): """Get OAuth client configuration for a provider""" - user = current_user - assert user.current_tenant_id is not None - try: provider_id = TriggerProviderID(provider) # Get custom OAuth client params if exists custom_params = TriggerProviderService.get_custom_oauth_client_params( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=provider_id, ) # Check if custom client is enabled is_custom_enabled = TriggerProviderService.is_oauth_custom_client_enabled( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=provider_id, ) system_client_exists = TriggerProviderService.is_oauth_system_client_exists( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=provider_id, ) - provider_controller = TriggerManager.get_trigger_provider(user.current_tenant_id, provider_id) + provider_controller = TriggerManager.get_trigger_provider(tenant_id, provider_id) redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/trigger/callback" return jsonable_encoder( { @@ -603,17 +576,15 @@ class TriggerOAuthClientManageApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def post(self, provider: str): + @with_current_tenant_id + def post(self, tenant_id: str, provider: str): """Configure custom OAuth client for a provider""" - user = current_user - assert user.current_tenant_id is not None - payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {}) try: provider_id = TriggerProviderID(provider) return TriggerProviderService.save_custom_oauth_client_params( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=provider_id, client_params=payload.client_params, enabled=payload.enabled, @@ -629,16 +600,14 @@ class TriggerOAuthClientManageApi(Resource): @login_required @is_admin_or_owner_required @account_initialization_required - def delete(self, provider: str): + @with_current_tenant_id + def delete(self, tenant_id: str, provider: str): """Remove custom OAuth client configuration""" - user = current_user - assert user.current_tenant_id is not None - try: provider_id = TriggerProviderID(provider) return TriggerProviderService.delete_custom_oauth_client_params( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, provider_id=provider_id, ) except ValueError as e: @@ -657,16 +626,15 @@ class TriggerSubscriptionVerifyApi(Resource): @login_required @edit_permission_required @account_initialization_required - def post(self, provider: str, subscription_id: str): + @with_current_user + @with_current_tenant_id + def post(self, tenant_id: str, user: Account, provider: str, subscription_id: str): """Verify credentials for an existing subscription (edit mode only)""" - user = current_user - assert user.current_tenant_id is not None - verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {}) try: result = TriggerProviderService.verify_subscription_credentials( - tenant_id=user.current_tenant_id, + tenant_id=tenant_id, user_id=user.id, provider_id=TriggerProviderID(provider), subscription_id=subscription_id, diff --git a/api/libs/login.py b/api/libs/login.py index 12d0f53f2d6..bbb8ba1611c 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Concatenate, cast, overload from flask import Response, current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS +from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy from configs import dify_config @@ -48,6 +49,53 @@ def current_account_with_tenant() -> tuple[Account, str]: return user, user.current_tenant_id +def current_account_with_tenant_optional() -> tuple[Account | None, str | None]: + try: + user = _resolve_current_user() + except Unauthorized: + return None, None + + if not isinstance(user, Account): + return None, None + if not bool(getattr(user, "is_authenticated", False)): + return None, None + return user, user.current_tenant_id + + +def resolve_account_fallback( + current_user: Account | None = None, + current_tenant_id: str | None = None, + *, + fallback_tenant_id: str | None = None, +) -> tuple[Account, str]: + """ + If the provided current user and tenant ID is None, fallback to current_account_with_tenant. + This is useful for those service layers whose controllers are not migrated to use DI for + resolving current user yet. + + TODO: this should be removed after all ctrls (especially service API) are migrated + """ + if current_user is not None: + tenant_id = current_tenant_id or fallback_tenant_id + if tenant_id is None: + raise ValueError("current_tenant_id is required when current_user is provided.") + return current_user, tenant_id + return current_account_with_tenant() + + +def resolve_tenant_id_fallback(current_tenant_id: str | None = None) -> str: + """ + If the provided tenant ID is None, fallback to the tenant resolved from current_account_with_tenant. + This is useful for tenant-only service paths whose controllers are not all migrated to tenant injection yet. + + TODO: this should be removed after all ctrls (especially service API) are migrated + """ + if current_tenant_id is not None: + return current_tenant_id + _, tenant_id = current_account_with_tenant() + return tenant_id + + @overload def login_required[T, **P, R]( func: Callable[Concatenate[T, P], R], diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 672f309bac0..f9dcfd25c7f 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -7,7 +7,8 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField, Metad from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now -from libs.login import current_account_with_tenant +from libs.login import resolve_account_fallback +from models import Account from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding from models.enums import DatasetMetadataType from services.dataset_service import DocumentService @@ -21,11 +22,16 @@ logger = logging.getLogger(__name__) class MetadataService: @staticmethod - def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: + def create_metadata( + dataset_id: str, + metadata_args: MetadataArgs, + current_user: Account | None = None, # TODO: the service_api is not migrated yet + current_tenant_id: str | None = None, + ) -> DatasetMetadata: # check if metadata name is too long if len(metadata_args.name) > 255: raise ValueError("Metadata name cannot exceed 255 characters.") - current_user, current_tenant_id = current_account_with_tenant() + current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) # check if metadata name already exists if db.session.scalar( select(DatasetMetadata) @@ -52,14 +58,20 @@ class MetadataService: return metadata @staticmethod - def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore + def update_metadata_name( + dataset_id: str, + metadata_id: str, + name: str, + current_user: Account | None = None, + current_tenant_id: str | None = None, # TODO: the service_api is not migrated yet + ) -> DatasetMetadata | None: # check if metadata name is too long if len(name) > 255: raise ValueError("Metadata name cannot exceed 255 characters.") lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists - current_user, current_tenant_id = current_account_with_tenant() + current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) if db.session.scalar( select(DatasetMetadata) .where( @@ -107,6 +119,7 @@ class MetadataService: return metadata except Exception: logger.exception("Update metadata name failed") + return None finally: redis_client.delete(lock_key) @@ -217,7 +230,15 @@ class MetadataService: redis_client.delete(lock_key) @staticmethod - def update_documents_metadata(dataset: Dataset, metadata_args: MetadataOperationData): + def update_documents_metadata( + dataset: Dataset, + metadata_args: MetadataOperationData, + current_user: Account | None = None, # TODO: the service_api is not migrated yet + current_tenant_id: str | None = None, + ): + current_user, current_tenant_id = resolve_account_fallback( + current_user, current_tenant_id, fallback_tenant_id=dataset.tenant_id + ) for operation in metadata_args.operation_data: lock_key = f"document_metadata_lock_{operation.document_id}" try: @@ -248,7 +269,6 @@ class MetadataService: ) ) - current_user, current_tenant_id = current_account_with_tenant() for metadata_value in operation.metadata_list: # check if binding already exists if operation.partial_update: diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index 3ba7593be53..4e4cf2d19f5 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -21,7 +21,8 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.BUILTIN @override - def get_pipeline_templates(self, language: str) -> dict[str, Any]: + def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: + del current_tenant_id result = self.fetch_pipeline_templates_from_builtin(language) return result diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index ee73b0328f5..57dfefed2e0 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -4,7 +4,7 @@ import yaml from sqlalchemy import select from extensions.ext_database import db -from libs.login import current_account_with_tenant +from libs.login import resolve_tenant_id_fallback from models.dataset import PipelineCustomizedTemplate from services.rag_pipeline.pipeline_template.pipeline_template_base import PipelineTemplateRetrievalBase from services.rag_pipeline.pipeline_template.pipeline_template_type import PipelineTemplateType @@ -40,8 +40,8 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ @override - def get_pipeline_templates(self, language: str) -> dict[str, Any]: - _, current_tenant_id = current_account_with_tenant() + def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: + current_tenant_id = resolve_tenant_id_fallback(current_tenant_id) return self.fetch_pipeline_templates_from_customized(tenant_id=current_tenant_id, language=language) @override diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 9c94fdee2b0..0f6d0727c76 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -40,7 +40,8 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ @override - def get_pipeline_templates(self, language: str) -> dict[str, Any]: + def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: + del current_tenant_id return self.fetch_pipeline_templates_from_db(language) @override diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index 9cfb8f36aa7..84d8f5674bb 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -4,7 +4,7 @@ from typing import Any, Protocol class PipelineTemplateRetrievalBase(Protocol): """Interface for pipeline template retrieval.""" - def get_pipeline_templates(self, language: str) -> dict[str, Any]: ... + def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: ... def get_pipeline_template_detail(self, template_id: str) -> dict[str, Any] | None: ... diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 1be97c2888d..5cf46915ab0 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -25,7 +25,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return DatabasePipelineTemplateRetrieval.fetch_pipeline_template_detail_from_db(template_id) @override - def get_pipeline_templates(self, language: str) -> dict[str, Any]: + def get_pipeline_templates(self, language: str, current_tenant_id: str | None = None) -> dict[str, Any]: + del current_tenant_id try: return self.fetch_pipeline_templates_from_dify_official(language) except Exception as e: diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index fd02a44995d..abab174b3d9 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -8,7 +8,6 @@ from datetime import UTC, datetime from typing import Any, cast from uuid import uuid4 -from flask_login import current_user from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -54,6 +53,7 @@ from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_htt from graphon.runtime import VariablePool from graphon.variables.variables import Variable, VariableBase from libs.infinite_scroll_pagination import InfiniteScrollPagination +from libs.login import resolve_account_fallback, resolve_tenant_id_fallback from models import Account from models.dataset import ( # type: ignore Dataset, @@ -104,11 +104,16 @@ class RagPipelineService: self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) @classmethod - def get_pipeline_templates(cls, type: str = "built-in", language: str = "en-US") -> dict[str, Any]: + def get_pipeline_templates( + cls, + type: str = "built-in", + language: str = "en-US", + current_tenant_id: str | None = None, + ) -> dict[str, Any]: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - result = retrieval_instance.get_pipeline_templates(language) + result = retrieval_instance.get_pipeline_templates(language, current_tenant_id) if not result.get("pipeline_templates") and language != "en-US": template_retrieval = PipelineTemplateRetrievalFactory.get_built_in_pipeline_template_retrieval() result = template_retrieval.fetch_pipeline_templates_from_builtin("en-US") @@ -116,7 +121,7 @@ class RagPipelineService: else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - result = retrieval_instance.get_pipeline_templates(language) + result = retrieval_instance.get_pipeline_templates(language, current_tenant_id) return result @classmethod @@ -146,17 +151,24 @@ class RagPipelineService: return customized_result @classmethod - def update_customized_pipeline_template(cls, template_id: str, template_info: PipelineTemplateInfoEntity): + def update_customized_pipeline_template( + cls, + template_id: str, + template_info: PipelineTemplateInfoEntity, + current_user: Account | None = None, + current_tenant_id: str | None = None, + ): """ Update pipeline template. :param template_id: template id :param template_info: template info """ + current_user, current_tenant_id = resolve_account_fallback(current_user, current_tenant_id) customized_template: PipelineCustomizedTemplate | None = db.session.scalar( select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.id == template_id, - PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + PipelineCustomizedTemplate.tenant_id == current_tenant_id, ) .limit(1) ) @@ -169,7 +181,7 @@ class RagPipelineService: select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.name == template_name, - PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + PipelineCustomizedTemplate.tenant_id == current_tenant_id, PipelineCustomizedTemplate.id != template_id, ) .limit(1) @@ -184,15 +196,16 @@ class RagPipelineService: return customized_template @classmethod - def delete_customized_pipeline_template(cls, template_id: str): + def delete_customized_pipeline_template(cls, template_id: str, current_tenant_id: str | None = None): """ Delete customized pipeline template. """ + current_tenant_id = resolve_tenant_id_fallback(current_tenant_id) customized_template: PipelineCustomizedTemplate | None = db.session.scalar( select(PipelineCustomizedTemplate) .where( PipelineCustomizedTemplate.id == template_id, - PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, + PipelineCustomizedTemplate.tenant_id == current_tenant_id, ) .limit(1) ) @@ -1174,10 +1187,17 @@ class RagPipelineService: return list(node_executions) @classmethod - def publish_customized_pipeline_template(cls, pipeline_id: str, args: dict[str, Any]): + def publish_customized_pipeline_template( + cls, + pipeline_id: str, + args: dict[str, Any], + current_user: Account | None = None, + current_tenant_id: str | None = None, + ): """ Publish customized pipeline template """ + current_user, _ = resolve_account_fallback(current_user, current_tenant_id) pipeline = db.session.get(Pipeline, pipeline_id) if not pipeline: raise ValueError("Pipeline not found") @@ -1357,7 +1377,7 @@ class RagPipelineService: return [] return marketplace.batch_fetch_plugin_by_ids(plugin_ids) - def get_recommended_plugins(self, type: str) -> dict[str, Any]: + def get_recommended_plugins(self, type: str, current_user: Account, current_tenant_id: str) -> dict[str, Any]: # Query active recommended plugins stmt = select(PipelineRecommendedPlugin).where(PipelineRecommendedPlugin.active == True) if type and type != "all": @@ -1375,7 +1395,7 @@ class RagPipelineService: plugin_ids = [plugin.plugin_id for plugin in pipeline_recommended_plugins] providers = BuiltinToolManageService.list_builtin_tools( user_id=current_user.id, - tenant_id=current_user.current_tenant_id, + tenant_id=current_tenant_id, ) providers_map = {provider.plugin_id: provider.to_dict() for provider in providers} diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 027356b6278..75b0d3c5002 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Callable +from inspect import unwrap from typing import cast from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -20,8 +21,8 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import ( PipelineTemplateListApi, PublishCustomizedPipelineTemplateApi, ) +from models.account import Account from models.dataset import PipelineCustomizedTemplate -from tests.test_containers_integration_tests.controllers.console.helpers import unwrap class TestPipelineTemplateListApi: @@ -53,7 +54,7 @@ class TestPipelineTemplateListApi: return_value=templates, ), ): - response, status = method(api) + response, status = method(api, str(uuid4())) assert status == 200 assert response == { @@ -147,6 +148,9 @@ class TestCustomizedPipelineTemplateApi: def test_patch_success(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() method = unwrap(api.patch) + account = Account(name="Test User", email="test@example.com") + account.id = str(uuid4()) + tenant_id = str(uuid4()) payload = { "name": "Template", @@ -161,15 +165,18 @@ class TestCustomizedPipelineTemplateApi: "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template" ) as update_mock, ): - response, status = method(api, "tpl-1") + response, status = method(api, tenant_id, account, "tpl-1") update_mock.assert_called_once() + assert update_mock.call_args.args[2] is account + assert update_mock.call_args.args[3] == tenant_id assert status == 204 assert response == "" def test_delete_success(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() method = unwrap(api.delete) + tenant_id = str(uuid4()) with ( app.test_request_context("/"), @@ -177,9 +184,9 @@ class TestCustomizedPipelineTemplateApi: "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template" ) as delete_mock, ): - response, status = method(api, "tpl-1") + response, status = method(api, tenant_id, "tpl-1") - delete_mock.assert_called_once_with("tpl-1") + delete_mock.assert_called_once_with("tpl-1", tenant_id) assert status == 204 assert response == "" @@ -227,6 +234,9 @@ class TestPublishCustomizedPipelineTemplateApi: def test_post_success(self, app: Flask) -> None: api = PublishCustomizedPipelineTemplateApi() method = unwrap(api.post) + account = Account(name="Test User", email="test@example.com") + account.id = str(uuid4()) + tenant_id = str(uuid4()) payload = { "name": "Template", @@ -244,8 +254,10 @@ class TestPublishCustomizedPipelineTemplateApi: return_value=service, ), ): - response, status = method(api, "pipeline-1") + response, status = method(api, tenant_id, account, "pipeline-1") service.publish_customized_pipeline_template.assert_called_once() + assert service.publish_customized_pipeline_template.call_args.args[2] is account + assert service.publish_customized_pipeline_template.call_args.args[3] == tenant_id assert status == 204 assert response == "" diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index d1d8e6fd757..bdec903ef33 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -5,7 +5,6 @@ from __future__ import annotations import json from datetime import datetime from inspect import unwrap -from types import SimpleNamespace from typing import TypedDict, Unpack from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -35,12 +34,15 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( RagPipelineTaskStopApi, RagPipelineTransformApi, RagPipelineWorkflowLastRunApi, + RagPipelineWorkflowRunNodeExecutionListApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from graphon.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.account import Account, TenantAccountRole from models.dataset import Pipeline -from models.workflow import Workflow +from models.enums import CreatorUserRole +from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -71,7 +73,7 @@ class WorkflowFactoryPayload(TypedDict): created_by: str created_at: datetime updated_by: str | None - updated_at: datetime + updated_at: datetime | None environment_variables: list[WorkflowVariablePayload] conversation_variables: list[WorkflowVariablePayload] rag_pipeline_variables: list[WorkflowVariablePayload] @@ -90,39 +92,69 @@ class WorkflowFactoryOverrides(TypedDict, total=False): created_by: str created_at: datetime updated_by: str | None - updated_at: datetime + updated_at: datetime | None environment_variables: list[WorkflowVariablePayload] conversation_variables: list[WorkflowVariablePayload] rag_pipeline_variables: list[WorkflowVariablePayload] -def make_node_execution(**overrides: object) -> SimpleNamespace: - payload: dict[str, object] = { +class NodeExecutionOverrides(TypedDict, total=False): + id: str + tenant_id: str + app_id: str + workflow_id: str + workflow_run_id: str | None + index: int + predecessor_node_id: str | None + node_execution_id: str | None + node_id: str + node_type: str + title: str + inputs: str | None + process_data: str | None + outputs: str | None + status: WorkflowNodeExecutionStatus + error: str | None + elapsed_time: float + execution_metadata: str | None + created_at: datetime + created_by_role: CreatorUserRole + created_by: str + finished_at: datetime | None + + +def make_node_execution(**overrides: Unpack[NodeExecutionOverrides]) -> WorkflowNodeExecutionModel: + payload: NodeExecutionOverrides = { "id": "node-exec-1", + "tenant_id": DEFAULT_WORKFLOW_TENANT_ID, + "app_id": DEFAULT_WORKFLOW_APP_ID, + "workflow_id": "workflow-1", + "workflow_run_id": None, "index": 1, "predecessor_node_id": None, + "node_execution_id": None, "node_id": "node1", "node_type": "start", "title": "Start", - "inputs_dict": {"query": "hello"}, - "process_data_dict": {}, - "outputs_dict": {"answer": "world"}, - "status": "succeeded", + "inputs": json.dumps({"query": "hello"}), + "process_data": json.dumps({}), + "outputs": json.dumps({"answer": "world"}), + "status": WorkflowNodeExecutionStatus.SUCCEEDED, "error": None, "elapsed_time": 1.0, - "execution_metadata_dict": {}, - "extras": {}, + "execution_metadata": json.dumps({}), "created_at": datetime(2026, 1, 1, 0, 0, 0), - "created_by_role": "account", - "created_by_account": None, - "created_by_end_user": None, + "created_by_role": CreatorUserRole.ACCOUNT, + "created_by": DEFAULT_WORKFLOW_CREATED_BY, "finished_at": datetime(2026, 1, 1, 0, 0, 1), - "inputs_truncated": False, - "outputs_truncated": False, - "process_data_truncated": False, } payload.update(overrides) - return SimpleNamespace(**payload) + execution = WorkflowNodeExecutionModel( + triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN, + **payload, + ) + execution.offload_data = [] + return execution def default_workflow_payload() -> WorkflowFactoryPayload: @@ -274,7 +306,10 @@ class TestDraftWorkflowApi: pipeline = make_pipeline() user = make_account(id="account-1") - workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1)) + workflow = make_workflow( + graph=json.dumps({"nodes": [{"id": "restored"}], "edges": []}), + created_at=datetime(2024, 1, 1), + ) service = MagicMock() service.restore_published_workflow_to_draft.return_value = workflow @@ -289,7 +324,7 @@ class TestDraftWorkflowApi: result = method(api, user, pipeline, "published-workflow") assert result["result"] == "success" - assert result["hash"] == "restored-hash" + assert result["hash"] == workflow.unique_hash def test_restore_published_workflow_to_draft_not_found(self, app: Flask) -> None: api = RagPipelineDraftWorkflowRestoreApi() @@ -515,10 +550,7 @@ class TestPublishedPipelineApis: user = make_account(id="u1") - workflow = MagicMock( - id=str(uuid4()), - created_at=naive_utc_now(), - ) + workflow = make_workflow(id=str(uuid4()), created_at=naive_utc_now()) service = MagicMock() service.publish_workflow.return_value = workflow @@ -576,6 +608,8 @@ class TestMiscApis: service = MagicMock() service.get_recommended_plugins.return_value = [{"id": "p1"}] + user = make_account() + tenant_id = "tenant-1" with ( app.test_request_context("/?type=all"), @@ -584,8 +618,9 @@ class TestMiscApis: return_value=service, ), ): - result = method(api) + result = method(api, tenant_id, user) assert result == [{"id": "p1"}] + service.get_recommended_plugins.assert_called_once_with("all", user, tenant_id) class TestPublishedRagPipelineRunApi: @@ -814,7 +849,7 @@ class TestRagPipelineWorkflowLastRunApi: method = unwrap(api.get) pipeline = make_pipeline() - workflow = MagicMock() + workflow = make_workflow() node_exec = make_node_execution() service = MagicMock() @@ -853,6 +888,42 @@ class TestRagPipelineWorkflowLastRunApi: method(api, pipeline, "node1") +class TestRagPipelineWorkflowRunNodeExecutionListApi: + @pytest.fixture + def app(self, flask_app_with_containers: Flask) -> Flask: + return flask_app_with_containers + + def test_get_node_executions_passes_current_user(self, app: Flask) -> None: + api = RagPipelineWorkflowRunNodeExecutionListApi() + method = unwrap(api.get) + + user = make_account() + pipeline = make_pipeline() + run_id = uuid4() + node_exec = make_node_execution(workflow_run_id=str(run_id)) + + service = MagicMock() + service.get_rag_pipeline_workflow_run_node_executions.return_value = [node_exec] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, user, pipeline, run_id) + + service.get_rag_pipeline_workflow_run_node_executions.assert_called_once_with( + pipeline=pipeline, + run_id=str(run_id), + user=user, + ) + assert result["data"][0]["id"] == "node-exec-1" + assert result["data"][0]["inputs"] == {"query": "hello"} + assert result["data"][0]["outputs"] == {"answer": "world"} + + class TestRagPipelineDatasourceVariableApi: @pytest.fixture def app(self, flask_app_with_containers: Flask) -> Flask: diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index b5f5917ee99..3d5fce4b6ca 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -2,7 +2,10 @@ from __future__ import annotations -from unittest.mock import MagicMock, patch +from dataclasses import dataclass +from inspect import unwrap +from typing import cast +from unittest.mock import patch import pytest from flask import Flask @@ -10,85 +13,103 @@ from werkzeug.exceptions import NotFound import controllers.console.explore.conversation as conversation_module from controllers.console.explore.error import NotChatAppError +from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account -from models.model import AppMode +from models.enums import ConversationFromSource, ConversationStatus +from models.model import App, AppMode, Conversation, InstalledApp from services.errors.conversation import ( ConversationNotExistsError, LastConversationNotExistsError, ) -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - -class FakeConversation: - def __init__(self, cid): - self.id = cid - self.name = "test" - self.inputs = {} - self.status = "normal" - self.introduction = "" +@dataclass +class InstalledAppCarrier: + app: App | None @pytest.fixture -def chat_app(): - app_model = MagicMock(mode=AppMode.CHAT, id="app-id") - return MagicMock(app=app_model) +def chat_app() -> InstalledApp: + app_model = App( + tenant_id="tenant-1", + name="Chat App", + mode=AppMode.CHAT, + enable_site=True, + enable_api=False, + ) + app_model.id = "app-id" + return cast(InstalledApp, InstalledAppCarrier(app=app_model)) @pytest.fixture -def non_chat_app(): - app_model = MagicMock(mode=AppMode.COMPLETION) - return MagicMock(app=app_model) +def non_chat_app() -> InstalledApp: + app_model = App( + tenant_id="tenant-1", + name="Completion App", + mode=AppMode.COMPLETION, + enable_site=True, + enable_api=False, + ) + app_model.id = "app-id" + return cast(InstalledApp, InstalledAppCarrier(app=app_model)) + + +def make_conversation(*, id: str) -> Conversation: + conversation = Conversation( + app_id="app-id", + mode=AppMode.CHAT, + name="test", + from_source=ConversationFromSource.API, + ) + conversation.id = id + conversation.inputs = {} + conversation.status = ConversationStatus.NORMAL + conversation.introduction = "" + return conversation @pytest.fixture -def user(): - user = MagicMock(spec=Account) +def user() -> Account: + user = Account(name="User", email="user.com") user.id = "uid" return user class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_success(self, app: Flask, chat_app, user): + def test_get_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationListApi() method = unwrap(api.get) - pagination = MagicMock( + pagination = InfiniteScrollPagination( + data=[make_conversation(id="c1"), make_conversation(id="c2")], limit=20, has_more=False, - data=[FakeConversation("c1"), FakeConversation("c2")], ) with ( app.test_request_context("/?limit=20"), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.WebConversationService, "pagination_by_last_id", return_value=pagination, ), ): - result = method(chat_app) + result = method(api, user, chat_app) assert result["limit"] == 20 assert result["has_more"] is False assert len(result["data"]) == 2 - def test_last_conversation_not_exists(self, app: Flask, chat_app, user): + def test_last_conversation_not_exists(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationListApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.WebConversationService, "pagination_by_last_id", @@ -96,47 +117,45 @@ class TestConversationListApi: ), ): with pytest.raises(NotFound): - method(chat_app) + method(api, user, chat_app) - def test_wrong_app_mode(self, app: Flask, non_chat_app): + def test_wrong_app_mode(self, app: Flask, non_chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationListApi() method = unwrap(api.get) with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(non_chat_app) + method(api, user, non_chat_app) class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_delete_success(self, app: Flask, chat_app, user): + def test_delete_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationApi() method = unwrap(api.delete) with ( app.test_request_context("/"), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.ConversationService, "delete", ), ): - result = method(chat_app, "cid") + result = method(api, user, chat_app, "cid") body, status = result assert status == 204 assert body == "" - def test_delete_not_found(self, app: Flask, chat_app, user): + def test_delete_not_found(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationApi() method = unwrap(api.delete) with ( app.test_request_context("/"), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.ConversationService, "delete", @@ -144,48 +163,46 @@ class TestConversationApi: ), ): with pytest.raises(NotFound): - method(chat_app, "cid") + method(api, user, chat_app, "cid") - def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): + def test_delete_wrong_app_mode(self, app: Flask, non_chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationApi() method = unwrap(api.delete) with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(non_chat_app, "cid") + method(api, user, non_chat_app, "cid") class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_rename_success(self, app: Flask, chat_app, user): + def test_rename_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationRenameApi() method = unwrap(api.post) - conversation = FakeConversation("cid") + conversation = make_conversation(id="cid") with ( app.test_request_context("/", json={"name": "new"}), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.ConversationService, "rename", return_value=conversation, ), ): - result = method(chat_app, "cid") + result = method(api, user, chat_app, "cid") assert result["id"] == "cid" - def test_rename_not_found(self, app: Flask, chat_app, user): + def test_rename_not_found(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationRenameApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"name": "new"}), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.ConversationService, "rename", @@ -193,48 +210,46 @@ class TestConversationRenameApi: ), ): with pytest.raises(NotFound): - method(chat_app, "cid") + method(api, user, chat_app, "cid") class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_pin_success(self, app: Flask, chat_app, user): + def test_pin_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationPinApi() method = unwrap(api.patch) with ( app.test_request_context("/"), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.WebConversationService, "pin", ), ): - result = method(chat_app, "cid") + result = method(api, user, chat_app, "cid") assert result == {"result": "success"} class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_unpin_success(self, app: Flask, chat_app, user): + def test_unpin_success(self, app: Flask, chat_app: InstalledApp, user: Account) -> None: api = conversation_module.ConversationUnPinApi() method = unwrap(api.patch) with ( app.test_request_context("/"), - patch.object(conversation_module, "current_user", user), patch.object( conversation_module.WebConversationService, "unpin", ), ): - result = method(chat_app, "cid") + result = method(api, user, chat_app, "cid") assert result == {"result": "success"} diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index 6c74b3193b9..6684381880c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -2,6 +2,7 @@ from __future__ import annotations +from inspect import unwrap from unittest.mock import MagicMock, patch import pytest @@ -31,123 +32,110 @@ from core.plugin.entities.plugin_daemon import CredentialType from models.account import Account -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - -def mock_user(): - user = MagicMock(spec=Account) +def mock_user() -> Account: + user = Account(name="User", email="user.com") user.id = "u1" - user.current_tenant_id = "t1" return user class TestTriggerProviderApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_icon_success(self, app: Flask): + def test_icon_success(self, app: Flask) -> None: api = TriggerProviderIconApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon", return_value="icon", ), ): - assert method(api, "github") == "icon" + assert method(api, "t1", "github") == "icon" - def test_list_providers(self, app: Flask): + def test_list_providers(self, app: Flask) -> None: api = TriggerProviderListApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers", return_value=[], ), ): - assert method(api) == [] + assert method(api, "t1") == [] - def test_provider_info(self, app: Flask): + def test_provider_info(self, app: Flask) -> None: api = TriggerProviderInfoApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider", return_value={"id": "p1"}, ), ): - assert method(api, "github") == {"id": "p1"} + assert method(api, "t1", "github") == {"id": "p1"} class TestTriggerSubscriptionListApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_list_success(self, app: Flask): + def test_list_success(self, app: Flask) -> None: api = TriggerSubscriptionListApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", return_value=[], ), ): - assert method(api, "github") == [] + assert method(api, "t1", mock_user(), "github") == [] - def test_list_invalid_provider(self, app: Flask): + def test_list_invalid_provider(self, app: Flask) -> None: api = TriggerSubscriptionListApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", side_effect=ValueError("bad"), ), ): - result, status = method(api, "bad") + result, status = method(api, "t1", mock_user(), "bad") assert status == 404 class TestTriggerSubscriptionBuilderApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_create_builder(self, app: Flask): + def test_create_builder(self, app: Flask) -> None: api = TriggerSubscriptionBuilderCreateApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", return_value={"id": "b1"}, ), ): - result = method(api, "github") + result = method(api, "t1", mock_user(), "github") assert "subscription_builder" in result - def test_get_builder(self, app: Flask): + def test_get_builder(self, app: Flask) -> None: api = TriggerSubscriptionBuilderGetApi() method = unwrap(api.get) @@ -160,50 +148,47 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_verify_builder(self, app: Flask): + def test_verify_builder(self, app: Flask) -> None: api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"credentials": {"a": 1}}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", return_value={"ok": True}, ), ): - assert method(api, "github", "b1") == {"ok": True} + assert method(api, "t1", mock_user(), "github", "b1") == {"ok": True} - def test_verify_builder_error(self, app: Flask): + def test_verify_builder_error(self, app: Flask) -> None: api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"credentials": {}}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", side_effect=Exception("err"), ), ): with pytest.raises(ValueError): - method(api, "github", "b1") + method(api, "t1", mock_user(), "github", "b1") - def test_update_builder(self, app: Flask): + def test_update_builder(self, app: Flask) -> None: api = TriggerSubscriptionBuilderUpdateApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"name": "n"}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", return_value={"id": "b1"}, ), ): - assert method(api, "github", "b1") == {"id": "b1"} + assert method(api, "t1", "github", "b1") == {"id": "b1"} - def test_logs(self, app: Flask): + def test_logs(self, app: Flask) -> None: api = TriggerSubscriptionBuilderLogsApi() method = unwrap(api.get) @@ -212,7 +197,6 @@ class TestTriggerSubscriptionBuilderApis: with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs", return_value=[log], @@ -220,27 +204,26 @@ class TestTriggerSubscriptionBuilderApis: ): assert "logs" in method(api, "github", "b1") - def test_build(self, app: Flask): + def test_build(self, app: Flask) -> None: api = TriggerSubscriptionBuilderBuildApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"name": "x"}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder", return_value=None, ), ): - assert method(api, "github", "b1") == 200 + assert method(api, "t1", mock_user(), "github", "b1") == 200 class TestTriggerSubscriptionCrud: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_update_rename_only(self, app: Flask): + def test_update_rename_only(self, app: Flask) -> None: api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -250,43 +233,40 @@ class TestTriggerSubscriptionCrud: with ( app.test_request_context("/", json={"name": "x"}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", return_value=sub, ), patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"), ): - assert method(api, "s1") == 200 + assert method(api, "t1", "s1") == 200 - def test_update_not_found(self, app: Flask): + def test_update_not_found(self, app: Flask) -> None: api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"name": "x"}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", return_value=None, ), ): with pytest.raises(NotFoundError): - method(api, "x") + method(api, "t1", "x") - def test_update_rebuild(self, app: Flask): + def test_update_rebuild(self, app: Flask) -> None: api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) sub = MagicMock() sub.provider_id = "github" sub.credential_type = CredentialType.OAUTH2 - sub.credentials = {} - sub.parameters = {} + sub.credentials = {"token": "old"} + sub.parameters = {"repo": "demo"} with ( app.test_request_context("/", json={"credentials": {}}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", return_value=sub, @@ -295,9 +275,9 @@ class TestTriggerSubscriptionCrud: "controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription" ), ): - assert method(api, "s1") == 200 + assert method(api, "t1", "s1") == 200 - def test_delete_subscription(self, app: Flask): + def test_delete_subscription(self, app: Flask) -> None: api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -305,7 +285,6 @@ class TestTriggerSubscriptionCrud: with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch("controllers.console.workspace.trigger_providers.db") as mock_db, patch("controllers.console.workspace.trigger_providers.sessionmaker") as mock_session_cls, patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"), @@ -316,17 +295,16 @@ class TestTriggerSubscriptionCrud: mock_db.engine = MagicMock() mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session - result = method(api, "sub1") + result = method(api, "t1", "sub1") assert result["result"] == "success" - def test_delete_subscription_value_error(self, app: Flask): + def test_delete_subscription_value_error(self, app: Flask) -> None: api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch("controllers.console.workspace.trigger_providers.db") as mock_db, patch("controllers.console.workspace.trigger_providers.sessionmaker") as session_cls, patch( @@ -338,21 +316,20 @@ class TestTriggerSubscriptionCrud: session_cls.return_value.begin.return_value.__enter__.return_value = MagicMock() with pytest.raises(BadRequest): - method(api, "sub1") + method(api, "t1", "sub1") class TestTriggerOAuthApis: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_oauth_authorize_success(self, app: Flask): + def test_oauth_authorize_success(self, app: Flask) -> None: api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", return_value={"a": 1}, @@ -370,25 +347,24 @@ class TestTriggerOAuthApis: return_value=MagicMock(authorization_url="url"), ), ): - resp = method(api, "github") + resp = method(api, "t1", mock_user(), "github") assert resp.status_code == 200 - def test_oauth_authorize_no_client(self, app: Flask): + def test_oauth_authorize_no_client(self, app: Flask) -> None: api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", return_value=None, ), ): with pytest.raises(NotFoundError): - method(api, "github") + method(api, "t1", mock_user(), "github") - def test_oauth_callback_forbidden(self, app: Flask): + def test_oauth_callback_forbidden(self, app: Flask) -> None: api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -396,7 +372,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_success(self, app: Flask): + def test_oauth_callback_success(self, app: Flask) -> None: api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -426,7 +402,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 302 - def test_oauth_callback_no_oauth_client(self, app: Flask): + def test_oauth_callback_no_oauth_client(self, app: Flask) -> None: api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -450,7 +426,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_empty_credentials(self, app: Flask): + def test_oauth_callback_empty_credentials(self, app: Flask) -> None: api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -481,16 +457,15 @@ class TestTriggerOAuthApis: class TestTriggerOAuthClientManageApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_get_client(self, app: Flask): + def test_get_client(self, app: Flask) -> None: api = TriggerOAuthClientManageApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params", return_value={}, @@ -508,84 +483,79 @@ class TestTriggerOAuthClientManageApi: return_value=MagicMock(get_oauth_client_schema=lambda: {}), ), ): - result = method(api, "github") + result = method(api, "t1", "github") assert "configured" in result - def test_post_client(self, app: Flask): + def test_post_client(self, app: Flask) -> None: api = TriggerOAuthClientManageApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"enabled": True}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", return_value={"ok": True}, ), ): - assert method(api, "github") == {"ok": True} + assert method(api, "t1", "github") == {"ok": True} - def test_delete_client(self, app: Flask): + def test_delete_client(self, app: Flask) -> None: api = TriggerOAuthClientManageApi() method = unwrap(api.delete) with ( app.test_request_context("/"), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params", return_value={"ok": True}, ), ): - assert method(api, "github") == {"ok": True} + assert method(api, "t1", "github") == {"ok": True} - def test_oauth_client_post_value_error(self, app: Flask): + def test_oauth_client_post_value_error(self, app: Flask) -> None: api = TriggerOAuthClientManageApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"enabled": True}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", side_effect=ValueError("bad"), ), ): with pytest.raises(BadRequest): - method(api, "github") + method(api, "t1", "github") class TestTriggerSubscriptionVerifyApi: @pytest.fixture - def app(self, flask_app_with_containers: Flask): + def app(self, flask_app_with_containers: Flask) -> Flask: return flask_app_with_containers - def test_verify_success(self, app: Flask): + def test_verify_success(self, app: Flask) -> None: api = TriggerSubscriptionVerifyApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"credentials": {}}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", return_value={"ok": True}, ), ): - assert method(api, "github", "s1") == {"ok": True} + assert method(api, "t1", mock_user(), "github", "s1") == {"ok": True} @pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")]) - def test_verify_errors(self, app: Flask, raised_exception): + def test_verify_errors(self, app: Flask, raised_exception: Exception) -> None: api = TriggerSubscriptionVerifyApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"credentials": {}}), - patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), patch( "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", side_effect=raised_exception, ), ): with pytest.raises(BadRequest): - method(api, "github", "s1") + method(api, "t1", mock_user(), "github", "s1") diff --git a/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py index 8fc1809a467..8f126e1cff0 100644 --- a/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py +++ b/api/tests/test_containers_integration_tests/services/rag_pipeline/test_rag_pipeline_service_db.py @@ -11,19 +11,29 @@ Covers: """ from collections.abc import Generator -from types import SimpleNamespace from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session, sessionmaker +from models import Account, Tenant from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate from models.enums import DataSourceType from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService +def _make_account(account_id: str, tenant_id: str) -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + tenant = Tenant(name="Test Tenant") + tenant.id = tenant_id + account._current_tenant = tenant + return account + + class TestRagPipelineServiceGetPipeline: """Integration tests for RagPipelineService.get_pipeline.""" @@ -32,7 +42,7 @@ class TestRagPipelineServiceGetPipeline: yield db_session_with_containers.rollback() - def _make_service(self, flask_app_with_containers) -> RagPipelineService: + def _make_service(self, flask_app_with_containers: Flask) -> RagPipelineService: with ( patch( "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository", @@ -72,7 +82,7 @@ class TestRagPipelineServiceGetPipeline: return dataset def test_get_pipeline_raises_when_dataset_not_found( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ) -> None: """get_pipeline raises ValueError when dataset does not exist.""" service = self._make_service(flask_app_with_containers) @@ -81,7 +91,7 @@ class TestRagPipelineServiceGetPipeline: service.get_pipeline(tenant_id=str(uuid4()), dataset_id=str(uuid4())) def test_get_pipeline_raises_when_pipeline_not_found( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ) -> None: """get_pipeline raises ValueError when dataset exists but has no linked pipeline.""" tenant_id = str(uuid4()) @@ -95,7 +105,7 @@ class TestRagPipelineServiceGetPipeline: service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset.id) def test_get_pipeline_returns_pipeline_when_found( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ) -> None: """get_pipeline returns the Pipeline when both Dataset and Pipeline exist.""" tenant_id = str(uuid4()) @@ -139,43 +149,44 @@ class TestUpdateCustomizedPipelineTemplate: db_session.flush() return template - def test_update_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + def test_update_template_succeeds( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ) -> None: """update_customized_pipeline_template updates name and description.""" tenant_id = str(uuid4()) created_by = str(uuid4()) template = self._create_template(db_session_with_containers, tenant_id, created_by) db_session_with_containers.flush() - fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + account = _make_account(created_by, tenant_id) - with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): - info = PipelineTemplateInfoEntity( - name="Updated Name", - description="Updated description", - icon_info=IconInfo(icon="🔥"), - ) - result = RagPipelineService.update_customized_pipeline_template(template.id, info) + info = PipelineTemplateInfoEntity( + name="Updated Name", + description="Updated description", + icon_info=IconInfo(icon="🔥"), + ) + result = RagPipelineService.update_customized_pipeline_template(template.id, info, account, tenant_id) assert result.name == "Updated Name" assert result.description == "Updated description" def test_update_template_raises_when_not_found( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ) -> None: """update_customized_pipeline_template raises ValueError when template doesn't exist.""" - fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + tenant_id = str(uuid4()) + account = _make_account(str(uuid4()), tenant_id) - with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): - info = PipelineTemplateInfoEntity( - name="New Name", - description="desc", - icon_info=IconInfo(icon="📄"), - ) - with pytest.raises(ValueError, match="Customized pipeline template not found"): - RagPipelineService.update_customized_pipeline_template(str(uuid4()), info) + info = PipelineTemplateInfoEntity( + name="New Name", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.update_customized_pipeline_template(str(uuid4()), info, account, tenant_id) def test_update_template_raises_on_duplicate_name( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ) -> None: """update_customized_pipeline_template raises ValueError when new name already exists.""" tenant_id = str(uuid4()) @@ -184,16 +195,15 @@ class TestUpdateCustomizedPipelineTemplate: self._create_template(db_session_with_containers, tenant_id, created_by, name="Duplicate") db_session_with_containers.flush() - fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + account = _make_account(created_by, tenant_id) - with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): - info = PipelineTemplateInfoEntity( - name="Duplicate", - description="desc", - icon_info=IconInfo(icon="📄"), - ) - with pytest.raises(ValueError, match="Template name is already exists"): - RagPipelineService.update_customized_pipeline_template(template1.id, info) + info = PipelineTemplateInfoEntity( + name="Duplicate", + description="desc", + icon_info=IconInfo(icon="📄"), + ) + with pytest.raises(ValueError, match="Template name is already exists"): + RagPipelineService.update_customized_pipeline_template(template1.id, info, account, tenant_id) class TestDeleteCustomizedPipelineTemplate: @@ -221,7 +231,9 @@ class TestDeleteCustomizedPipelineTemplate: db_session.flush() return template - def test_delete_template_succeeds(self, db_session_with_containers: Session, flask_app_with_containers) -> None: + def test_delete_template_succeeds( + self, db_session_with_containers: Session, flask_app_with_containers: Flask + ) -> None: """delete_customized_pipeline_template removes the template from the DB.""" tenant_id = str(uuid4()) created_by = str(uuid4()) @@ -229,27 +241,23 @@ class TestDeleteCustomizedPipelineTemplate: template_id = template.id db_session_with_containers.flush() - fake_user = SimpleNamespace(id=created_by, current_tenant_id=tenant_id) + RagPipelineService.delete_customized_pipeline_template(template_id, tenant_id) - with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): - RagPipelineService.delete_customized_pipeline_template(template_id) + # Verify the record is deleted within the same context + from sqlalchemy import select - # Verify the record is deleted within the same context - from sqlalchemy import select + from extensions.ext_database import db as ext_db - from extensions.ext_database import db as ext_db - - remaining = ext_db.session.scalar( - select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id) - ) - assert remaining is None + remaining = ext_db.session.scalar( + select(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id) + ) + assert remaining is None def test_delete_template_raises_when_not_found( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ) -> None: """delete_customized_pipeline_template raises ValueError when template doesn't exist.""" - fake_user = SimpleNamespace(id=str(uuid4()), current_tenant_id=str(uuid4())) + tenant_id = str(uuid4()) - with patch("services.rag_pipeline.rag_pipeline.current_user", fake_user): - with pytest.raises(ValueError, match="Customized pipeline template not found"): - RagPipelineService.delete_customized_pipeline_template(str(uuid4())) + with pytest.raises(ValueError, match="Customized pipeline template not found"): + RagPipelineService.delete_customized_pipeline_template(str(uuid4()), tenant_id) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index f3ab9eb3da8..5844441e6a5 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -1,6 +1,6 @@ from __future__ import annotations -from unittest.mock import Mock, patch +from unittest.mock import patch from uuid import uuid4 import pytest @@ -8,6 +8,7 @@ from flask import Flask from sqlalchemy import select from sqlalchemy.orm import Session +from models import Account, Tenant from models.dataset import Dataset, DatasetMetadataBinding, Document from models.enums import DataSourceType, DocumentCreatedFrom from services.entities.knowledge_entities.knowledge_entities import ( @@ -33,7 +34,7 @@ def _create_dataset(db_session: Session, *, tenant_id: str, built_in_field_enabl def _create_document( - db_session: Session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None + db_session: Session, *, dataset_id: str, tenant_id: str, doc_metadata: dict[str, str] | None = None ) -> Document: document = Document( tenant_id=tenant_id, @@ -63,18 +64,21 @@ class TestMetadataPartialUpdate: return str(uuid4()) @pytest.fixture - def mock_current_account(self, user_id, tenant_id): - account = Mock(id=user_id, current_tenant_id=tenant_id) - with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)): - yield account + def current_account(self, user_id: str, tenant_id: str) -> Account: + account = Account(name="Test User", email=f"{user_id}@example.com") + account.id = user_id + tenant = Tenant(name="Test Tenant") + tenant.id = tenant_id + account._current_tenant = tenant + return account def test_partial_update_merges_metadata( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id: str, - mock_current_account, - ): + current_account: Account, + ) -> None: dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( db_session_with_containers, @@ -91,7 +95,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(dataset, metadata_args, current_account) db_session_with_containers.expire_all() updated_doc = db_session_with_containers.get(Document, document.id) @@ -104,8 +108,8 @@ class TestMetadataPartialUpdate: flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id: str, - mock_current_account, - ): + current_account: Account, + ) -> None: dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( db_session_with_containers, @@ -122,7 +126,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(dataset, metadata_args, current_account) db_session_with_containers.expire_all() updated_doc = db_session_with_containers.get(Document, document.id) @@ -134,10 +138,10 @@ class TestMetadataPartialUpdate: self, flask_app_with_containers: Flask, db_session_with_containers: Session, - tenant_id, - user_id, - mock_current_account, - ): + tenant_id: str, + user_id: str, + current_account: Account, + ) -> None: dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( db_session_with_containers, @@ -164,7 +168,7 @@ class TestMetadataPartialUpdate: ) metadata_args = MetadataOperationData(operation_data=[operation]) - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(dataset, metadata_args, current_account) db_session_with_containers.expire_all() bindings = db_session_with_containers.scalars( @@ -180,8 +184,8 @@ class TestMetadataPartialUpdate: flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id: str, - mock_current_account, - ): + current_account: Account, + ) -> None: dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( db_session_with_containers, @@ -200,4 +204,4 @@ class TestMetadataPartialUpdate: with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): with pytest.raises(RuntimeError, match="database connection lost"): - MetadataService.update_documents_metadata(dataset, metadata_args) + MetadataService.update_documents_metadata(dataset, metadata_args, current_account) diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index 8b1349be9a8..0c9e3830430 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -1,4 +1,6 @@ -from unittest.mock import create_autospec, patch +from collections.abc import Generator +from typing import TypedDict +from unittest.mock import Mock, patch import pytest from faker import Faker @@ -6,21 +8,25 @@ from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document -from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom +from models.enums import DataSourceType, DocumentCreatedFrom from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService +class MetadataServiceDeps(TypedDict): + redis_client: Mock + document_service: Mock + + class TestMetadataService: """Integration tests for MetadataService using testcontainers.""" @pytest.fixture - def mock_external_service_dependencies(self): + def mock_external_service_dependencies(self) -> Generator[MetadataServiceDeps, None, None]: """Mock setup for external service dependencies.""" with ( - patch("libs.login.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.metadata_service.redis_client") as mock_redis_client, patch("services.dataset_service.DocumentService") as mock_document_service, ): @@ -30,12 +36,15 @@ class TestMetadataService: mock_redis_client.delete.return_value = 1 yield { - "current_user": mock_current_user, "redis_client": mock_redis_client, "document_service": mock_document_service, } - def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): + def _create_test_account_and_tenant( + self, + db_session_with_containers: Session, + mock_external_service_dependencies: MetadataServiceDeps, + ) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -53,7 +62,7 @@ class TestMetadataService: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) @@ -62,7 +71,7 @@ class TestMetadataService: # Create tenant for the account tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -83,8 +92,12 @@ class TestMetadataService: return account, tenant def _create_test_dataset( - self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant - ): + self, + db_session_with_containers: Session, + mock_external_service_dependencies: MetadataServiceDeps, + account: Account, + tenant: Tenant, + ) -> Dataset: """ Helper method to create a test dataset for testing. @@ -114,8 +127,12 @@ class TestMetadataService: return dataset def _create_test_document( - self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account - ): + self, + db_session_with_containers: Session, + mock_external_service_dependencies: MetadataServiceDeps, + dataset: Dataset, + account: Account, + ) -> Document: """ Helper method to create a test document for testing. @@ -149,7 +166,9 @@ class TestMetadataService: return document - def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): + def test_create_metadata_success( + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful metadata creation with valid parameters. """ @@ -161,14 +180,10 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") + metadata_args = MetadataArgs(type="string", name="test_metadata") # Act: Execute the method under test - result = MetadataService.create_metadata(dataset.id, metadata_args) + result = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Assert: Verify the expected outcomes assert result is not None @@ -185,8 +200,8 @@ class TestMetadataService: assert result.created_at is not None def test_create_metadata_name_too_long( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata creation fails when name exceeds 255 characters. """ @@ -198,20 +213,16 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - long_name = "a" * 256 # 256 characters, exceeding 255 limit - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name) + metadata_args = MetadataArgs(type="string", name=long_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): - MetadataService.create_metadata(dataset.id, metadata_args) + MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) def test_create_metadata_name_already_exists( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata creation fails when name already exists in the same dataset. """ @@ -223,24 +234,20 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create first metadata - first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name") - MetadataService.create_metadata(dataset.id, first_metadata_args) + first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id) # Try to create second metadata with same name - second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name") + second_metadata_args = MetadataArgs(type="number", name="duplicate_name") # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists."): - MetadataService.create_metadata(dataset.id, second_metadata_args) + MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id) def test_create_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata creation fails when name conflicts with built-in field names. """ @@ -252,21 +259,17 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Try to create metadata with built-in field name built_in_field_name = BuiltInField.document_name - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name) + metadata_args = MetadataArgs(type="string", name=built_in_field_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): - MetadataService.create_metadata(dataset.id, metadata_args) + MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) def test_update_metadata_name_success( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful metadata name update with valid parameters. """ @@ -278,17 +281,13 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata first - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Act: Execute the method under test new_name = "new_name" - result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name) + result = MetadataService.update_metadata_name(dataset.id, metadata.id, new_name, account, tenant.id) # Assert: Verify the expected outcomes assert result is not None @@ -302,8 +301,8 @@ class TestMetadataService: assert result.name == new_name def test_update_metadata_name_too_long( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata name update fails when new name exceeds 255 characters. """ @@ -315,24 +314,20 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata first - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Try to update with too long name long_name = "a" * 256 # 256 characters, exceeding 255 limit # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): - MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) + MetadataService.update_metadata_name(dataset.id, metadata.id, long_name, account, tenant.id) def test_update_metadata_name_already_exists( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata name update fails when new name already exists in the same dataset. """ @@ -344,24 +339,20 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create two metadata entries - first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata") - first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) + first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args, account, tenant.id) - second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata") - second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) + second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args, account, tenant.id) # Try to update first metadata with second metadata's name with pytest.raises(ValueError, match="Metadata name already exists."): - MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") + MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata", account, tenant.id) def test_update_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata name update fails when new name conflicts with built-in field names. """ @@ -373,23 +364,19 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata first - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="old_name") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Try to update with built-in field name built_in_field_name = BuiltInField.document_name with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): - MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) + MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name, account, tenant.id) def test_update_metadata_name_not_found( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata name update fails when metadata ID does not exist. """ @@ -401,10 +388,6 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Try to update non-existent metadata import uuid @@ -412,12 +395,14 @@ class TestMetadataService: new_name = "new_name" # Act: Execute the method under test - result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name) + result = MetadataService.update_metadata_name(dataset.id, fake_metadata_id, new_name, account, tenant.id) # Assert: Verify the method returns None when metadata is not found assert result is None - def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): + def test_delete_metadata_success( + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful metadata deletion with valid parameters. """ @@ -429,13 +414,9 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata first - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Act: Execute the method under test result = MetadataService.delete_metadata(dataset.id, metadata.id) @@ -449,7 +430,9 @@ class TestMetadataService: deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None - def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): + def test_delete_metadata_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata deletion fails when metadata ID does not exist. """ @@ -461,10 +444,6 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Try to delete non-existent metadata import uuid @@ -477,8 +456,8 @@ class TestMetadataService: assert result is None def test_delete_metadata_with_document_bindings( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata deletion successfully removes document metadata bindings. """ @@ -493,13 +472,9 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, dataset, account ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Create metadata binding binding = DatasetMetadataBinding( @@ -531,7 +506,9 @@ class TestMetadataService: # Note: The service attempts to update document metadata but may not succeed # due to mock configuration. The main functionality (metadata deletion) is verified. - def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies): + def test_get_built_in_fields_success( + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful retrieval of built-in metadata fields. """ @@ -557,8 +534,8 @@ class TestMetadataService: assert "time" in field_types def test_enable_built_in_field_success( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful enabling of built-in fields for a dataset. """ @@ -573,10 +550,6 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, dataset, account ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ document @@ -591,14 +564,14 @@ class TestMetadataService: # Assert: Verify the expected outcomes db_session_with_containers.refresh(dataset) - assert dataset.built_in_field_enabled is True + assert dataset.built_in_field_enabled # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (enabling built-in fields) is verified def test_enable_built_in_field_already_enabled( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test enabling built-in fields when they are already enabled. """ @@ -610,10 +583,6 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Enable built-in fields first dataset.built_in_field_enabled = True @@ -621,7 +590,9 @@ class TestMetadataService: db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id - mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[ + Document + ]() # Act: Execute the method under test MetadataService.enable_built_in_field(dataset) @@ -631,8 +602,8 @@ class TestMetadataService: assert dataset.built_in_field_enabled is True def test_enable_built_in_field_with_no_documents( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test enabling built-in fields for a dataset with no documents. """ @@ -644,12 +615,10 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Mock DocumentService.get_working_documents_by_dataset_id to return empty list - mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[ + Document + ]() # Act: Execute the method under test MetadataService.enable_built_in_field(dataset) @@ -657,11 +626,11 @@ class TestMetadataService: # Assert: Verify the expected outcomes db_session_with_containers.refresh(dataset) - assert dataset.built_in_field_enabled is True + assert dataset.built_in_field_enabled def test_disable_built_in_field_success( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful disabling of built-in fields for a dataset. """ @@ -676,10 +645,6 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, dataset, account ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Enable built-in fields first dataset.built_in_field_enabled = True @@ -713,8 +678,8 @@ class TestMetadataService: # The main functionality (disabling built-in fields) is verified def test_disable_built_in_field_already_disabled( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test disabling built-in fields when they are already disabled. """ @@ -726,15 +691,13 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Verify dataset starts with built-in fields disabled assert dataset.built_in_field_enabled is False # Mock DocumentService.get_working_documents_by_dataset_id - mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[ + Document + ]() # Act: Execute the method under test MetadataService.disable_built_in_field(dataset) @@ -742,11 +705,11 @@ class TestMetadataService: # Assert: Verify the method returns early without changes db_session_with_containers.refresh(dataset) - assert dataset.built_in_field_enabled is False + assert not dataset.built_in_field_enabled def test_disable_built_in_field_with_no_documents( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test disabling built-in fields for a dataset with no documents. """ @@ -758,10 +721,6 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Enable built-in fields first dataset.built_in_field_enabled = True @@ -769,7 +728,9 @@ class TestMetadataService: db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id to return empty list - mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] + mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = list[ + Document + ]() # Act: Execute the method under test MetadataService.disable_built_in_field(dataset) @@ -779,8 +740,8 @@ class TestMetadataService: assert dataset.built_in_field_enabled is False def test_update_documents_metadata_success( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful update of documents metadata. """ @@ -795,13 +756,9 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, dataset, account ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Mock DocumentService.get_document mock_external_service_dependencies["document_service"].get_document.return_value = document @@ -820,7 +777,7 @@ class TestMetadataService: operation_data = MetadataOperationData(operation_data=[operation]) # Act: Execute the method under test - MetadataService.update_documents_metadata(dataset, operation_data) + MetadataService.update_documents_metadata(dataset, operation_data, account) # Assert: Verify the expected outcomes @@ -841,8 +798,8 @@ class TestMetadataService: assert binding.dataset_id == dataset.id def test_update_documents_metadata_with_built_in_fields_enabled( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test update of documents metadata when built-in fields are enabled. """ @@ -863,13 +820,9 @@ class TestMetadataService: db_session_with_containers.add(dataset) db_session_with_containers.commit() - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Mock DocumentService.get_document mock_external_service_dependencies["document_service"].get_document.return_value = document @@ -888,7 +841,7 @@ class TestMetadataService: operation_data = MetadataOperationData(operation_data=[operation]) # Act: Execute the method under test - MetadataService.update_documents_metadata(dataset, operation_data) + MetadataService.update_documents_metadata(dataset, operation_data, account) # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields @@ -901,8 +854,8 @@ class TestMetadataService: # The main functionality (custom metadata update) is verified def test_update_documents_metadata_document_not_found( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test update of documents metadata when document is not found. """ @@ -914,13 +867,9 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Create metadata operation data from services.entities.knowledge_entities.knowledge_entities import ( @@ -941,11 +890,11 @@ class TestMetadataService: # Act & Assert: The method should raise ValueError("Document not found.") # because the exception is now re-raised after rollback with pytest.raises(ValueError, match="Document not found"): - MetadataService.update_documents_metadata(dataset, operation_data) + MetadataService.update_documents_metadata(dataset, operation_data, account) def test_knowledge_base_metadata_lock_check_dataset_id( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata lock check for dataset operations. """ @@ -967,8 +916,8 @@ class TestMetadataService: assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" def test_knowledge_base_metadata_lock_check_document_id( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata lock check for document operations. """ @@ -990,8 +939,8 @@ class TestMetadataService: assert call_args[0][0] == f"document_metadata_lock_{document_id}" def test_knowledge_base_metadata_lock_check_lock_exists( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata lock check when lock already exists. """ @@ -1007,8 +956,8 @@ class TestMetadataService: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) def test_knowledge_base_metadata_lock_check_document_lock_exists( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test metadata lock check when document lock already exists. """ @@ -1022,8 +971,8 @@ class TestMetadataService: MetadataService.knowledge_base_metadata_lock_check(None, document_id) def test_get_dataset_metadatas_success( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test successful retrieval of dataset metadata information. """ @@ -1035,13 +984,9 @@ class TestMetadataService: db_session_with_containers, mock_external_service_dependencies, account, tenant ) - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Create document and metadata binding document = self._create_test_document( @@ -1079,8 +1024,8 @@ class TestMetadataService: assert result["built_in_field_enabled"] is False def test_get_dataset_metadatas_with_built_in_fields_enabled( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test retrieval of dataset metadata when built-in fields are enabled. """ @@ -1098,13 +1043,9 @@ class TestMetadataService: db_session_with_containers.add(dataset) db_session_with_containers.commit() - # Setup mocks - mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id - mock_external_service_dependencies["current_user"].id = account.id - # Create metadata - metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") - metadata = MetadataService.create_metadata(dataset.id, metadata_args) + metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata = MetadataService.create_metadata(dataset.id, metadata_args, account, tenant.id) # Act: Execute the method under test result = MetadataService.get_dataset_metadatas(dataset) @@ -1122,8 +1063,8 @@ class TestMetadataService: assert result["built_in_field_enabled"] is True def test_get_dataset_metadatas_no_metadata( - self, db_session_with_containers: Session, mock_external_service_dependencies - ): + self, db_session_with_containers: Session, mock_external_service_dependencies: MetadataServiceDeps + ) -> None: """ Test retrieval of dataset metadata when no metadata exists. """ diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 5a66bc4e92f..bca2d73ad9f 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -15,6 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline import ( PipelineTemplateListApi, PublishCustomizedPipelineTemplateApi, ) +from models.account import Account from models.dataset import PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity @@ -50,24 +51,31 @@ def _payload() -> dict[str, object]: } +def _account() -> Account: + account = Account(name="Test User", email="test@example.com") + account.id = "account-1" + return account + + class TestPipelineTemplateListApi: def test_get_uses_query_defaults_and_serializes_nullable_fields(self, app: Flask) -> None: api = PipelineTemplateListApi() method = unwrap(api.get) - service_calls: list[tuple[str, str]] = [] + tenant_id = "tenant-1" + service_calls: list[tuple[str, str, str]] = [] - def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]: - service_calls.append((template_type, language)) + def get_pipeline_templates(template_type: str, language: str, current_tenant_id: str) -> dict[str, object]: + service_calls.append((template_type, language, current_tenant_id)) return {"pipeline_templates": [_template_item()]} with ( app.test_request_context("/rag/pipeline/templates"), patch.object(module.RagPipelineService, "get_pipeline_templates", side_effect=get_pipeline_templates), ): - response, status = method(api) + response, status = method(api, tenant_id) assert status == 200 - assert service_calls == [("built-in", "en-US")] + assert service_calls == [("built-in", "en-US", tenant_id)] assert response == { "pipeline_templates": [ { @@ -81,21 +89,22 @@ class TestPipelineTemplateListApi: def test_get_passes_explicit_query_to_service(self, app: Flask) -> None: api = PipelineTemplateListApi() method = unwrap(api.get) - service_calls: list[tuple[str, str]] = [] + tenant_id = "tenant-1" + service_calls: list[tuple[str, str, str]] = [] - def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]: - service_calls.append((template_type, language)) + def get_pipeline_templates(template_type: str, language: str, current_tenant_id: str) -> dict[str, object]: + service_calls.append((template_type, language, current_tenant_id)) return {"pipeline_templates": []} with ( app.test_request_context("/rag/pipeline/templates?type=customized&language=ja-JP"), patch.object(module.RagPipelineService, "get_pipeline_templates", side_effect=get_pipeline_templates), ): - response, status = method(api) + response, status = method(api, tenant_id) assert status == 200 assert response == {"pipeline_templates": []} - assert service_calls == [("customized", "ja-JP")] + assert service_calls == [("customized", "ja-JP", tenant_id)] class TestPipelineTemplateDetailApi: @@ -140,22 +149,28 @@ class TestCustomizedPipelineTemplateApi: api = CustomizedPipelineTemplateApi() method = unwrap(api.patch) payload = _payload() - service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = [] + account = _account() + tenant_id = "tenant-1" + service_calls: list[tuple[str, PipelineTemplateInfoEntity, Account, str]] = [] - def update_template(template_id: str, template_info: PipelineTemplateInfoEntity) -> None: - service_calls.append((template_id, template_info)) + def update_template( + template_id: str, template_info: PipelineTemplateInfoEntity, current_user: Account, current_tenant_id: str + ) -> None: + service_calls.append((template_id, template_info, current_user, current_tenant_id)) with ( app.test_request_context("/rag/pipeline/customized/templates/template-1", method="PATCH", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object(module.RagPipelineService, "update_customized_pipeline_template", side_effect=update_template), ): - response, status = method(api, "template-1") + response, status = method(api, tenant_id, account, "template-1") assert (response, status) == ("", 204) assert len(service_calls) == 1 - template_id, template_info = service_calls[0] + template_id, template_info, current_user, current_tenant_id = service_calls[0] assert template_id == "template-1" + assert current_user is account + assert current_tenant_id == tenant_id assert template_info.name == "Updated template" assert template_info.description == "Updated description" assert template_info.icon_info.model_dump() == { @@ -172,22 +187,28 @@ class TestCustomizedPipelineTemplateApi: "name": "Updated template", "description": "Updated description", } - service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = [] + account = _account() + tenant_id = "tenant-1" + service_calls: list[tuple[str, PipelineTemplateInfoEntity, Account, str]] = [] - def update_template(template_id: str, template_info: PipelineTemplateInfoEntity) -> None: - service_calls.append((template_id, template_info)) + def update_template( + template_id: str, template_info: PipelineTemplateInfoEntity, current_user: Account, current_tenant_id: str + ) -> None: + service_calls.append((template_id, template_info, current_user, current_tenant_id)) with ( app.test_request_context("/rag/pipeline/customized/templates/template-1", method="PATCH", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object(module.RagPipelineService, "update_customized_pipeline_template", side_effect=update_template), ): - response, status = method(api, "template-1") + response, status = method(api, tenant_id, account, "template-1") assert (response, status) == ("", 204) assert len(service_calls) == 1 - template_id, template_info = service_calls[0] + template_id, template_info, current_user, current_tenant_id = service_calls[0] assert template_id == "template-1" + assert current_user is account + assert current_tenant_id == tenant_id assert template_info.icon_info.model_dump() == { "icon": "", "icon_background": None, @@ -198,19 +219,20 @@ class TestCustomizedPipelineTemplateApi: def test_delete_returns_empty_204(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() method = unwrap(api.delete) - deleted_template_ids: list[str] = [] + tenant_id = "tenant-1" + deleted_templates: list[tuple[str, str]] = [] - def delete_template(template_id: str) -> None: - deleted_template_ids.append(template_id) + def delete_template(template_id: str, current_tenant_id: str) -> None: + deleted_templates.append((template_id, current_tenant_id)) with ( app.test_request_context("/rag/pipeline/customized/templates/template-1", method="DELETE"), patch.object(module.RagPipelineService, "delete_customized_pipeline_template", side_effect=delete_template), ): - response, status = method(api, "template-1") + response, status = method(api, tenant_id, "template-1") assert (response, status) == ("", 204) - assert deleted_template_ids == ["template-1"] + assert deleted_templates == [("template-1", tenant_id)] def test_post_exports_yaml_from_orm_template(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() @@ -292,21 +314,25 @@ class TestPublishCustomizedPipelineTemplateApi: api = PublishCustomizedPipelineTemplateApi() method = unwrap(api.post) payload = _payload() - service_calls: list[tuple[str, dict[str, object]]] = [] + account = _account() + tenant_id = "tenant-1" + service_calls: list[tuple[str, dict[str, object], Account, str]] = [] class Service: - def publish_customized_pipeline_template(self, pipeline_id: str, data: dict[str, object]) -> None: - service_calls.append((pipeline_id, data)) + def publish_customized_pipeline_template( + self, pipeline_id: str, data: dict[str, object], current_user: Account, current_tenant_id: str + ) -> None: + service_calls.append((pipeline_id, data, current_user, current_tenant_id)) with ( app.test_request_context("/rag/pipelines/pipeline-1/customized/publish", method="POST", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object(module, "RagPipelineService", Service), ): - response, status = method(api, "pipeline-1") + response, status = method(api, tenant_id, account, "pipeline-1") assert (response, status) == ("", 204) - assert service_calls == [("pipeline-1", payload)] + assert service_calls == [("pipeline-1", payload, account, tenant_id)] def test_post_allows_missing_icon_info_for_publish_service_fallback(self, app: Flask) -> None: api = PublishCustomizedPipelineTemplateApi() @@ -315,18 +341,22 @@ class TestPublishCustomizedPipelineTemplateApi: "name": "Published template", "description": "Description", } - service_calls: list[tuple[str, dict[str, object]]] = [] + account = _account() + tenant_id = "tenant-1" + service_calls: list[tuple[str, dict[str, object], Account, str]] = [] class Service: - def publish_customized_pipeline_template(self, pipeline_id: str, data: dict[str, object]) -> None: - service_calls.append((pipeline_id, data)) + def publish_customized_pipeline_template( + self, pipeline_id: str, data: dict[str, object], current_user: Account, current_tenant_id: str + ) -> None: + service_calls.append((pipeline_id, data, current_user, current_tenant_id)) with ( app.test_request_context("/rag/pipelines/pipeline-1/customized/publish", method="POST", json=payload), patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), patch.object(module, "RagPipelineService", Service), ): - response, status = method(api, "pipeline-1") + response, status = method(api, tenant_id, account, "pipeline-1") assert (response, status) == ("", 204) assert service_calls == [ @@ -341,5 +371,7 @@ class TestPublishCustomizedPipelineTemplateApi: "icon_url": None, }, }, + account, + tenant_id, ) ] diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py index faedd4d7e1d..3de780f3bbb 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -1,4 +1,5 @@ import uuid +from inspect import unwrap from unittest.mock import PropertyMock, patch import pytest @@ -8,16 +9,10 @@ from werkzeug.exceptions import NotFound from controllers.console import console_ns from controllers.console.datasets.hit_testing import HitTestingApi +from models.account import Account, Tenant, TenantAccountRole from models.dataset import Dataset -def unwrap(func): - """Recursively unwrap decorated functions.""" - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - @pytest.fixture def app(): app = Flask("test_hit_testing") @@ -35,6 +30,17 @@ def dataset(): return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1") +@pytest.fixture +def account() -> Account: + account = Account(name="User", email="user@example.com") + account.id = "account-1" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + account._current_tenant = tenant + account.role = TenantAccountRole.OWNER + return account + + def hit_testing_record() -> dict[str, object]: return { "segment": { @@ -98,7 +104,7 @@ def bypass_decorators(mocker: MockerFixture): class TestHitTestingApi: - def test_hit_testing_success(self, app: Flask, dataset, dataset_id): + def test_hit_testing_success(self, app: Flask, dataset, dataset_id, account: Account): api = HitTestingApi() method = unwrap(api.post) @@ -129,13 +135,13 @@ class TestHitTestingApi: return_value={"query": {"content": "what is vector search"}, "records": []}, ), ): - result = method(api, dataset_id) + result = method(api, account, "tenant-1", dataset_id) assert "query" in result assert "records" in result assert result["records"] == [] - def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id): + def test_hit_testing_success_with_optional_record_fields(self, app: Flask, dataset, dataset_id, account: Account): api = HitTestingApi() method = unwrap(api.post) @@ -167,7 +173,7 @@ class TestHitTestingApi: return_value={"query": {"content": payload["query"]}, "records": records}, ), ): - result = method(api, dataset_id) + result = method(api, account, "tenant-1", dataset_id) assert result["query"] == {"content": payload["query"]} assert result["records"][0]["segment"]["keywords"] == [] @@ -175,7 +181,7 @@ class TestHitTestingApi: assert result["records"][0]["files"] == [] assert result["records"][0]["score"] is None - def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id): + def test_hit_testing_dataset_not_found(self, app: Flask, dataset_id, account: Account): api = HitTestingApi() method = unwrap(api.post) @@ -198,9 +204,9 @@ class TestHitTestingApi: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, dataset_id) + method(api, account, "tenant-1", dataset_id) - def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id): + def test_hit_testing_invalid_args(self, app: Flask, dataset, dataset_id, account: Account): api = HitTestingApi() method = unwrap(api.post) @@ -228,4 +234,4 @@ class TestHitTestingApi: ), ): with pytest.raises(ValueError, match="Invalid parameters"): - method(api, dataset_id) + method(api, account, "tenant-1", dataset_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index 072aa559dff..0fcf0df5262 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -21,7 +21,7 @@ from core.errors.error import ( QuotaExceededError, ) from graphon.model_runtime.errors.invoke import InvokeError -from models.account import Account +from models.account import Account, Tenant, TenantAccountRole from models.dataset import Dataset from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService @@ -29,19 +29,15 @@ from services.hit_testing_service import HitTestingService @pytest.fixture def account(): - acc = MagicMock(spec=Account) + acc = Account(name="User", email="user@example.com") + acc.id = "account-1" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + acc._current_tenant = tenant + acc.role = TenantAccountRole.OWNER return acc -@pytest.fixture(autouse=True) -def patch_current_user(mocker, account): - """Patch current_user to a valid Account.""" - mocker.patch( - "controllers.console.datasets.hit_testing_base.current_user", - account, - ) - - @pytest.fixture def dataset(): return Dataset(id="dataset-1", tenant_id="tenant-1", name="Dataset", created_by="account-1") @@ -86,7 +82,7 @@ def hit_testing_record() -> dict[str, object]: class TestGetAndValidateDataset: - def test_success(self, dataset): + def test_success(self, dataset, account): with ( patch.object( DatasetService, @@ -98,20 +94,20 @@ class TestGetAndValidateDataset: "check_dataset_permission", ), ): - result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1") assert result == dataset - def test_dataset_not_found(self): + def test_dataset_not_found(self, account): with patch.object( DatasetService, "get_dataset", return_value=None, ): with pytest.raises(NotFound, match="Dataset not found"): - DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1") - def test_permission_denied(self, dataset): + def test_permission_denied(self, dataset, account): with ( patch.object( DatasetService, @@ -125,7 +121,7 @@ class TestGetAndValidateDataset: ), ): with pytest.raises(Forbidden, match="no access"): - DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1", account, "tenant-1") class TestHitTestingArgsCheck: @@ -164,7 +160,7 @@ class TestParseArgs: class TestPerformHitTesting: - def test_success(self, dataset): + def test_success(self, dataset, account): response = { "query": {"content": "hello"}, "records": [], @@ -175,12 +171,12 @@ class TestPerformHitTesting: "retrieve", return_value=response, ): - result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") assert result["query"] == {"content": "hello"} assert result["records"] == [] - def test_success_prepares_nullable_list_fields(self, dataset): + def test_success_prepares_nullable_list_fields(self, dataset, account): response = { "query": {"content": "hello"}, "records": [hit_testing_record()], @@ -191,7 +187,7 @@ class TestPerformHitTesting: "retrieve", return_value=response, ): - result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") assert result["query"] == {"content": "hello"} record = result["records"][0] @@ -203,7 +199,7 @@ class TestPerformHitTesting: assert record["tsne_position"] is None assert record["summary"] is None - def test_invalid_query_response_raises_value_error(self, dataset): + def test_invalid_query_response_raises_value_error(self, dataset, account): with ( patch.object( HitTestingService, @@ -212,7 +208,7 @@ class TestPerformHitTesting: ), pytest.raises(ValueError, match="Invalid hit testing query response"), ): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") def test_invalid_records_response_raises_value_error(self): with pytest.raises(ValueError, match="Invalid hit testing records response"): @@ -222,74 +218,74 @@ class TestPerformHitTesting: with pytest.raises(ValueError, match="Invalid hit testing record response"): DatasetsHitTestingBase._prepare_hit_testing_records(["record"]) - def test_index_not_initialized(self, dataset): + def test_index_not_initialized(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=services.errors.index.IndexNotInitializedError(), ): with pytest.raises(DatasetNotInitializedError): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_provider_token_not_init(self, dataset): + def test_provider_token_not_init(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=ProviderTokenNotInitError("token missing"), ): with pytest.raises(ProviderNotInitializeError): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_quota_exceeded(self, dataset): + def test_quota_exceeded(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=QuotaExceededError(), ): with pytest.raises(ProviderQuotaExceededError): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_model_not_supported(self, dataset): + def test_model_not_supported(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=ModelCurrentlyNotSupportError(), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_llm_bad_request(self, dataset): + def test_llm_bad_request(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=LLMBadRequestError("bad request"), ): with pytest.raises(ProviderNotInitializeError): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_invoke_error(self, dataset): + def test_invoke_error(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=InvokeError("invoke failed"), ): with pytest.raises(CompletionRequestError): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_value_error(self, dataset): + def test_value_error(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=ValueError("bad args"), ): with pytest.raises(ValueError, match="bad args"): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") - def test_unexpected_error(self, dataset): + def test_unexpected_error(self, dataset, account): with patch.object( HitTestingService, "retrieve", side_effect=Exception("boom"), ): with pytest.raises(InternalServerError, match="boom"): - DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}, account, "tenant-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py index 3015ed6604b..785c0ac09f2 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -1,4 +1,5 @@ import uuid +from inspect import unwrap from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -14,6 +15,7 @@ from controllers.console.datasets.metadata import ( DatasetMetadataCreateApi, DocumentMetadataEditApi, ) +from models.account import Account from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( MetadataArgs, @@ -22,13 +24,6 @@ from services.entities.knowledge_entities.knowledge_entities import ( from services.metadata_service import MetadataService -def unwrap(func): - """Recursively unwrap decorated functions.""" - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - @pytest.fixture def app(): app = Flask("test_dataset_metadata") @@ -37,8 +32,8 @@ def app(): @pytest.fixture -def current_user(): - user = MagicMock() +def current_user() -> Account: + user = Account(name="Test User", email="test@example.com") user.id = "user-1" return user @@ -116,7 +111,7 @@ class TestDatasetMetadataCreateApi: return_value={"id": "m1", "type": "string", "name": "author"}, ), ): - result, status = method(api, current_user, dataset_id) + result, status = method(api, "tenant-1", current_user, dataset_id) assert status == 201 assert result["type"] == "string" @@ -151,7 +146,7 @@ class TestDatasetMetadataCreateApi: ), ): with pytest.raises(NotFound, match="Dataset not found"): - method(api, current_user, dataset_id) + method(api, "tenant-1", current_user, dataset_id) class TestDatasetMetadataGetApi: @@ -227,7 +222,7 @@ class TestDatasetMetadataApi: return_value={"id": "m1", "type": "string", "name": "updated-name"}, ), ): - result, status = method(api, current_user, dataset_id, metadata_id) + result, status = method(api, "tenant-1", current_user, dataset_id, metadata_id) assert status == 200 assert result["type"] == "string" diff --git a/api/tests/unit_tests/controllers/console/explore/test_completion.py b/api/tests/unit_tests/controllers/console/explore/test_completion.py index 420392f1dfa..8b9121c4d7f 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_completion.py +++ b/api/tests/unit_tests/controllers/console/explore/test_completion.py @@ -1,3 +1,4 @@ +from inspect import unwrap from unittest.mock import MagicMock, PropertyMock, patch import pytest @@ -15,15 +16,11 @@ from models.model import AppMode from services.errors.llm import InvokeRateLimitError -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - @pytest.fixture def user(): - return MagicMock(spec=Account) + account = Account(name="User", email="user.com") + account.id = "uid" + return account @pytest.fixture @@ -59,7 +56,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -71,18 +67,18 @@ class TestCompletionApi: return_value=("ok", 200), ), ): - result = method(completion_app) + result = method(api, user, completion_app) assert result == ("ok", 200) - def test_post_wrong_app_mode(self): + def test_post_wrong_app_mode(self, user): api = completion_module.CompletionApi() method = unwrap(api.post) installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT)) with pytest.raises(NotCompletionAppError): - method(installed_app) + method(api, user, installed_app) def test_conversation_completed(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -91,7 +87,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -99,7 +94,7 @@ class TestCompletionApi: ), ): with pytest.raises(ConversationCompletedError): - method(completion_app) + method(api, user, completion_app) def test_internal_error(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -108,7 +103,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -116,7 +110,7 @@ class TestCompletionApi: ), ): with pytest.raises(InternalServerError): - method(completion_app) + method(api, user, completion_app) def test_conversation_not_exists(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -125,7 +119,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -133,7 +126,7 @@ class TestCompletionApi: ), ): with pytest.raises(completion_module.NotFound): - method(completion_app) + method(api, user, completion_app) def test_app_unavailable(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -142,7 +135,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -150,7 +142,7 @@ class TestCompletionApi: ), ): with pytest.raises(completion_module.AppUnavailableError): - method(completion_app) + method(api, user, completion_app) def test_provider_not_initialized(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -159,7 +151,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -167,7 +158,7 @@ class TestCompletionApi: ), ): with pytest.raises(completion_module.ProviderNotInitializeError): - method(completion_app) + method(api, user, completion_app) def test_quota_exceeded(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -176,7 +167,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -184,7 +174,7 @@ class TestCompletionApi: ), ): with pytest.raises(completion_module.ProviderQuotaExceededError): - method(completion_app) + method(api, user, completion_app) def test_model_not_supported(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -193,7 +183,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -201,7 +190,7 @@ class TestCompletionApi: ), ): with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError): - method(completion_app) + method(api, user, completion_app) def test_invoke_error(self, app: Flask, completion_app, user, payload_patch): api = completion_module.CompletionApi() @@ -210,7 +199,6 @@ class TestCompletionApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -218,7 +206,7 @@ class TestCompletionApi: ), ): with pytest.raises(completion_module.CompletionRequestError): - method(completion_app) + method(api, user, completion_app) class TestCompletionStopApi: @@ -250,7 +238,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -262,18 +249,18 @@ class TestChatApi: return_value=("ok", 200), ), ): - result = method(chat_app) + result = method(api, user, chat_app) assert result == ("ok", 200) - def test_post_not_chat_app(self): + def test_post_not_chat_app(self, user): api = completion_module.ChatApi() method = unwrap(api.post) installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) with pytest.raises(NotChatAppError): - method(installed_app) + method(api, user, installed_app) def test_rate_limit_error(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -282,7 +269,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -290,7 +276,7 @@ class TestChatApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(chat_app) + method(api, user, chat_app) def test_conversation_completed_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -299,7 +285,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -307,7 +292,7 @@ class TestChatApi: ), ): with pytest.raises(ConversationCompletedError): - method(chat_app) + method(api, user, chat_app) def test_conversation_not_exists_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -316,7 +301,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -324,7 +308,7 @@ class TestChatApi: ), ): with pytest.raises(completion_module.NotFound): - method(chat_app) + method(api, user, chat_app) def test_app_unavailable_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -333,7 +317,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -341,7 +324,7 @@ class TestChatApi: ), ): with pytest.raises(completion_module.AppUnavailableError): - method(chat_app) + method(api, user, chat_app) def test_provider_not_initialized_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -350,7 +333,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -358,7 +340,7 @@ class TestChatApi: ), ): with pytest.raises(completion_module.ProviderNotInitializeError): - method(chat_app) + method(api, user, chat_app) def test_quota_exceeded_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -367,7 +349,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -375,7 +356,7 @@ class TestChatApi: ), ): with pytest.raises(completion_module.ProviderQuotaExceededError): - method(chat_app) + method(api, user, chat_app) def test_model_not_supported_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -384,7 +365,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -392,7 +372,7 @@ class TestChatApi: ), ): with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError): - method(chat_app) + method(api, user, chat_app) def test_invoke_error_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -401,7 +381,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -409,7 +388,7 @@ class TestChatApi: ), ): with pytest.raises(completion_module.CompletionRequestError): - method(chat_app) + method(api, user, chat_app) def test_internal_error_chat(self, app: Flask, chat_app, user, payload_patch): api = completion_module.ChatApi() @@ -418,7 +397,6 @@ class TestChatApi: with ( app.test_request_context("/", json={}), payload_patch, - patch.object(completion_module, "current_user", user), patch.object( completion_module.AppGenerateService, "generate", @@ -426,7 +404,7 @@ class TestChatApi: ), ): with pytest.raises(InternalServerError): - method(chat_app) + method(api, user, chat_app) class TestChatStopApi: diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 641209d1deb..be68a3beed6 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -1,4 +1,6 @@ +from inspect import unwrap as inspect_unwrap from io import BytesIO +from typing import Any from unittest.mock import MagicMock, patch from uuid import uuid4 @@ -33,22 +35,24 @@ from models.model import AppMode from services.errors.conversation import ConversationNotExistsError from services.errors.llm import InvokeRateLimitError - -def unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +unwrap: Any = inspect_unwrap @pytest.fixture -def account(): - acc = MagicMock(spec=Account) +def account() -> Account: + acc = Account(name="User", email="user@example.com") acc.id = "u1" return acc +def _file_data() -> Any: + file_data: Any = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + return file_data + + @pytest.fixture -def trial_app_chat(): +def trial_app_chat() -> MagicMock: app = MagicMock() app.id = "a-chat" app.mode = AppMode.CHAT @@ -56,7 +60,7 @@ def trial_app_chat(): @pytest.fixture -def trial_app_completion(): +def trial_app_completion() -> MagicMock: app = MagicMock() app.id = "a-comp" app.mode = AppMode.COMPLETION @@ -64,7 +68,7 @@ def trial_app_completion(): @pytest.fixture -def trial_app_workflow(): +def trial_app_workflow() -> MagicMock: app = MagicMock() app.id = "a-workflow" app.mode = AppMode.WORKFLOW @@ -72,7 +76,7 @@ def trial_app_workflow(): @pytest.fixture -def valid_parameters(): +def valid_parameters() -> dict[str, object]: return { "user_input_form": [], "system_parameters": {}, @@ -88,41 +92,39 @@ def valid_parameters(): } -def test_trial_workflow_uses_trial_scoped_simple_account_model(): +def test_trial_workflow_uses_trial_scoped_simple_account_model() -> None: assert module.simple_account_model.name == "TrialSimpleAccount" assert hasattr(module.simple_account_model, "items") class TestTrialAppWorkflowRunApi: - def test_not_workflow_app(self, app: Flask): + def test_not_workflow_app(self, app: Flask, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(api, MagicMock(mode=AppMode.CHAT)) + method(api, account, MagicMock(mode=AppMode.CHAT)) - def test_success(self, app: Flask, trial_app_workflow, account): + def test_success(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}}), - patch.object(module, "current_user", account), patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(api, trial_app_workflow) + result = method(api, account, trial_app_workflow) assert result is not None - def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow, account): + def test_workflow_provider_not_init(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -130,15 +132,14 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) - def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow, account): + def test_workflow_quota_exceeded(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -146,15 +147,14 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) - def test_workflow_model_not_support(self, app: Flask, trial_app_workflow, account): + def test_workflow_model_not_support(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -162,15 +162,14 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) - def test_workflow_invoke_error(self, app: Flask, trial_app_workflow, account): + def test_workflow_invoke_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -178,15 +177,14 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(CompletionRequestError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) - def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow, account): + def test_workflow_rate_limit_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -194,15 +192,14 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) - def test_workflow_value_error(self, app: Flask, trial_app_workflow, account): + def test_workflow_value_error(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "files": []}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -210,15 +207,14 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(ValueError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) - def test_workflow_generic_exception(self, app: Flask, trial_app_workflow, account): + def test_workflow_generic_exception(self, app: Flask, trial_app_workflow: MagicMock, account: Account) -> None: api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "files": []}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -226,39 +222,37 @@ class TestTrialAppWorkflowRunApi: ), ): with pytest.raises(InternalServerError): - method(api, trial_app_workflow) + method(api, account, trial_app_workflow) class TestTrialChatApi: - def test_not_chat_app(self, app: Flask): + def test_not_chat_app(self, app: Flask, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with app.test_request_context("/", json={"inputs": {}, "query": "hi"}): with pytest.raises(NotChatAppError): - method(api, MagicMock(mode="completion")) + method(api, account, MagicMock(mode="completion")) - def test_success(self, app: Flask, trial_app_chat, account): + def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(api, trial_app_chat) + result = method(api, account, trial_app_chat) assert result is not None - def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat, account): + def test_chat_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -266,15 +260,14 @@ class TestTrialChatApi: ), ): with pytest.raises(NotFound): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_conversation_completed(self, app: Flask, trial_app_chat, account): + def test_chat_conversation_completed(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -282,15 +275,14 @@ class TestTrialChatApi: ), ): with pytest.raises(ConversationCompletedError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_app_config_broken(self, app: Flask, trial_app_chat, account): + def test_chat_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -298,15 +290,14 @@ class TestTrialChatApi: ), ): with pytest.raises(AppUnavailableError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_provider_not_init(self, app: Flask, trial_app_chat, account): + def test_chat_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -314,15 +305,14 @@ class TestTrialChatApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_quota_exceeded(self, app: Flask, trial_app_chat, account): + def test_chat_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -330,15 +320,14 @@ class TestTrialChatApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_model_not_support(self, app: Flask, trial_app_chat, account): + def test_chat_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -346,15 +335,14 @@ class TestTrialChatApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_invoke_error(self, app: Flask, trial_app_chat, account): + def test_chat_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -362,15 +350,14 @@ class TestTrialChatApi: ), ): with pytest.raises(CompletionRequestError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_rate_limit_error(self, app: Flask, trial_app_chat, account): + def test_chat_rate_limit_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -378,15 +365,14 @@ class TestTrialChatApi: ), ): with pytest.raises(InvokeRateLimitHttpError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_value_error(self, app: Flask, trial_app_chat, account): + def test_chat_value_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -394,15 +380,14 @@ class TestTrialChatApi: ), ): with pytest.raises(ValueError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_chat_generic_exception(self, app: Flask, trial_app_chat, account): + def test_chat_generic_exception(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": "hi"}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -410,39 +395,37 @@ class TestTrialChatApi: ), ): with pytest.raises(InternalServerError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) class TestTrialCompletionApi: - def test_not_completion_app(self, app: Flask): + def test_not_completion_app(self, app: Flask, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with app.test_request_context("/", json={"inputs": {}, "query": ""}): with pytest.raises(NotCompletionAppError): - method(api, MagicMock(mode=AppMode.CHAT)) + method(api, account, MagicMock(mode=AppMode.CHAT)) - def test_success(self, app: Flask, trial_app_completion, account): + def test_success(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(api, trial_app_completion) + result = method(api, account, trial_app_completion) assert result is not None - def test_completion_app_config_broken(self, app: Flask, trial_app_completion, account): + def test_completion_app_config_broken(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -450,15 +433,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(AppUnavailableError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_provider_not_init(self, app: Flask, trial_app_completion, account): + def test_completion_provider_not_init(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -466,15 +448,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_quota_exceeded(self, app: Flask, trial_app_completion, account): + def test_completion_quota_exceeded(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -482,15 +463,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(ProviderQuotaExceededError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_model_not_support(self, app: Flask, trial_app_completion, account): + def test_completion_model_not_support(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -498,15 +478,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_invoke_error(self, app: Flask, trial_app_completion, account): + def test_completion_invoke_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -514,15 +493,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(CompletionRequestError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_rate_limit_error(self, app: Flask, trial_app_completion, account): + def test_completion_rate_limit_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -530,15 +508,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(InternalServerError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_value_error(self, app: Flask, trial_app_completion, account): + def test_completion_value_error(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -546,15 +523,14 @@ class TestTrialCompletionApi: ), ): with pytest.raises(ValueError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) - def test_completion_generic_exception(self, app: Flask, trial_app_completion, account): + def test_completion_generic_exception(self, app: Flask, trial_app_completion: MagicMock, account: Account) -> None: api = module.TrialCompletionApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"inputs": {}, "query": ""}), - patch.object(module, "current_user", account), patch.object( module.AppGenerateService, "generate", @@ -562,42 +538,40 @@ class TestTrialCompletionApi: ), ): with pytest.raises(InternalServerError): - method(api, trial_app_completion) + method(api, account, trial_app_completion) class TestTrialMessageSuggestedQuestionApi: - def test_not_chat_app(self, app: Flask): + def test_not_chat_app(self, app: Flask, account: Account) -> None: api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) with app.test_request_context("/"): with pytest.raises(NotChatAppError): - method(MagicMock(mode="completion"), str(uuid4())) + method(api, account, MagicMock(mode="completion"), str(uuid4())) - def test_success(self, app: Flask, trial_app_chat, account): + def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch.object(module, "current_user", account), patch.object( module.MessageService, "get_suggested_questions_after_answer", return_value=["q1", "q2"], ), ): - result = method(trial_app_chat, str(uuid4())) + result = method(api, account, trial_app_chat, str(uuid4())) assert result == {"data": ["q1", "q2"]} - def test_conversation_not_exists(self, app: Flask, trial_app_chat, account): + def test_conversation_not_exists(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) with ( app.test_request_context("/"), - patch.object(module, "current_user", account), patch.object( module.MessageService, "get_suggested_questions_after_answer", @@ -605,18 +579,18 @@ class TestTrialMessageSuggestedQuestionApi: ), ): with pytest.raises(NotFound): - method(trial_app_chat, str(uuid4())) + method(api, account, trial_app_chat, str(uuid4())) class TestTrialAppParameterApi: - def test_app_unavailable(self): + def test_app_unavailable(self) -> None: api = module.TrialAppParameterApi() method = unwrap(api.get) with pytest.raises(AppUnavailableError): method(api, None) - def test_success_non_workflow(self, valid_parameters): + def test_success_non_workflow(self, valid_parameters: dict[str, object]) -> None: api = module.TrialAppParameterApi() method = unwrap(api.get) @@ -643,37 +617,33 @@ class TestTrialAppParameterApi: class TestTrialChatAudioApi: - def test_success(self, app: Flask, trial_app_chat, account): + def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_asr", return_value={"text": "hello"}), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(api, trial_app_chat) + result = method(api, account, trial_app_chat) assert result == {"text": "hello"} - def test_app_config_broken(self, app: Flask, trial_app_chat, account): + def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -681,20 +651,18 @@ class TestTrialChatAudioApi: ), ): with pytest.raises(module.AppUnavailableError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account): + def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -702,20 +670,18 @@ class TestTrialChatAudioApi: ), ): with pytest.raises(module.NoAudioUploadedError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_audio_too_large(self, app: Flask, trial_app_chat, account): + def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -723,20 +689,18 @@ class TestTrialChatAudioApi: ), ): with pytest.raises(module.AudioTooLargeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account): + def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -744,20 +708,18 @@ class TestTrialChatAudioApi: ), ): with pytest.raises(module.UnsupportedAudioTypeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_provider_not_support_tts(self, app: Flask, trial_app_chat, account): + def test_provider_not_support_tts(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -765,65 +727,59 @@ class TestTrialChatAudioApi: ), ): with pytest.raises(module.ProviderNotSupportSpeechToTextError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_provider_not_init(self, app: Flask, trial_app_chat, account): + def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_asr", side_effect=ProviderTokenNotInitError("test")), ): with pytest.raises(ProviderNotInitializeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_quota_exceeded(self, app: Flask, trial_app_chat, account): + def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_asr", side_effect=QuotaExceededError()), ): with pytest.raises(ProviderQuotaExceededError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) class TestTrialChatTextApi: - def test_success(self, app: Flask, trial_app_chat, account): + def test_success(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_tts", return_value={"audio": "base64_data"}), patch.object(module.RecommendedAppService, "add_trial_app_record"), ): - result = method(api, trial_app_chat) + result = method(api, account, trial_app_chat) assert result == {"audio": "base64_data"} - def test_app_config_broken(self, app: Flask, trial_app_chat, account): + def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_tts", @@ -831,15 +787,14 @@ class TestTrialChatTextApi: ), ): with pytest.raises(module.AppUnavailableError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_provider_not_support(self, app: Flask, trial_app_chat, account): + def test_provider_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_tts", @@ -847,15 +802,14 @@ class TestTrialChatTextApi: ), ): with pytest.raises(module.ProviderNotSupportSpeechToTextError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_audio_too_large(self, app: Flask, trial_app_chat, account): + def test_audio_too_large(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_tts", @@ -863,15 +817,14 @@ class TestTrialChatTextApi: ), ): with pytest.raises(module.AudioTooLargeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_no_audio_uploaded(self, app: Flask, trial_app_chat, account): + def test_no_audio_uploaded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_tts", @@ -879,59 +832,55 @@ class TestTrialChatTextApi: ), ): with pytest.raises(module.NoAudioUploadedError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_provider_not_init(self, app: Flask, trial_app_chat, account): + def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_tts", side_effect=ProviderTokenNotInitError("test")), ): with pytest.raises(ProviderNotInitializeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_quota_exceeded(self, app: Flask, trial_app_chat, account): + def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_tts", side_effect=QuotaExceededError()), ): with pytest.raises(ProviderQuotaExceededError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_model_not_support(self, app: Flask, trial_app_chat, account): + def test_model_not_support(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_tts", side_effect=ModelCurrentlyNotSupportError()), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_invoke_error(self, app: Flask, trial_app_chat, account): + def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object(module.AudioService, "transcript_tts", side_effect=InvokeError("test error")), ): with pytest.raises(CompletionRequestError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) class TestTrialAppWorkflowTaskStopApi: - def test_not_workflow_app(self, app: Flask, trial_app_chat): + def test_not_workflow_app(self, app: Flask, trial_app_chat: MagicMock) -> None: api = module.TrialAppWorkflowTaskStopApi() method = unwrap(api.post) @@ -939,14 +888,13 @@ class TestTrialAppWorkflowTaskStopApi: with pytest.raises(NotWorkflowAppError): method(api, trial_app_chat, str(uuid4())) - def test_success(self, app: Flask, trial_app_workflow, account): + def test_success(self, app: Flask, trial_app_workflow: MagicMock) -> None: api = module.TrialAppWorkflowTaskStopApi() method = unwrap(api.post) task_id = str(uuid4()) with ( app.test_request_context("/"), - patch.object(module, "current_user", account), patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, patch.object(module.GraphEngineManager, "send_stop_command") as mock_send_cmd, ): @@ -958,7 +906,7 @@ class TestTrialAppWorkflowTaskStopApi: class TestTrialSitApi: - def test_no_site(self, app: Flask): + def test_no_site(self, app: Flask) -> None: api = module.TrialSitApi() method = unwrap(api.get) app_model = MagicMock() @@ -969,7 +917,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_archived_tenant(self, app: Flask): + def test_archived_tenant(self, app: Flask) -> None: api = module.TrialSitApi() method = unwrap(api.get) @@ -984,7 +932,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_success(self, app: Flask): + def test_success(self, app: Flask) -> None: api = module.TrialSitApi() method = unwrap(api.get) @@ -1009,18 +957,16 @@ class TestTrialSitApi: class TestTrialChatAudioApiExceptionHandlers: - def test_provider_not_init(self, app: Flask, trial_app_chat, account): + def test_provider_not_init(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -1028,20 +974,18 @@ class TestTrialChatAudioApiExceptionHandlers: ), ): with pytest.raises(ProviderNotInitializeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_quota_exceeded(self, app: Flask, trial_app_chat, account): + def test_quota_exceeded(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -1049,20 +993,18 @@ class TestTrialChatAudioApiExceptionHandlers: ), ): with pytest.raises(ProviderQuotaExceededError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_invoke_error(self, app: Flask, trial_app_chat, account): + def test_invoke_error(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatAudioApi() method = unwrap(api.post) - file_data = BytesIO(b"fake audio data") - file_data.filename = "test.wav" + file_data = _file_data() with ( app.test_request_context( "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" ), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_asr", @@ -1070,17 +1012,16 @@ class TestTrialChatAudioApiExceptionHandlers: ), ): with pytest.raises(CompletionRequestError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) class TestTrialChatTextApiExceptionHandlers: - def test_app_config_broken(self, app: Flask, trial_app_chat, account): + def test_app_config_broken(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_tts", @@ -1088,15 +1029,14 @@ class TestTrialChatTextApiExceptionHandlers: ), ): with pytest.raises(module.AppUnavailableError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) - def test_unsupported_audio_type(self, app: Flask, trial_app_chat, account): + def test_unsupported_audio_type(self, app: Flask, trial_app_chat: MagicMock, account: Account) -> None: api = module.TrialChatTextApi() method = unwrap(api.post) with ( app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), - patch.object(module, "current_user", account), patch.object( module.AudioService, "transcript_tts", @@ -1104,4 +1044,4 @@ class TestTrialChatTextApiExceptionHandlers: ), ): with pytest.raises(module.UnsupportedAudioTypeError): - method(api, trial_app_chat) + method(api, account, trial_app_chat) diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py index d92454d6d84..3e804583a6e 100644 --- a/api/tests/unit_tests/controllers/console/test_feature.py +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -1,32 +1,36 @@ +from inspect import unwrap + from pytest_mock import MockerFixture -from werkzeug.exceptions import Unauthorized + +from models import Account +from services.feature_service import FeatureModel, LimitationModel, SystemFeatureModel -def unwrap(func): - """ - Recursively unwrap decorated functions. - """ - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func +def make_account() -> Account: + account = Account(name="Alice", email="alice@example.com") + account.id = "account-1" + return account class TestFeatureApi: def test_get_tenant_features_success(self, mocker: MockerFixture): from controllers.console.feature import FeatureApi + features = FeatureModel( + knowledge_rate_limit=42, + vector_space=LimitationModel(size=1, limit=2), + ) get_features = mocker.patch("controllers.console.feature.FeatureService.get_features") - get_features.return_value.model_dump.return_value = { - "features": {"feature_a": True}, - "vector_space": {"size": 1, "limit": 2}, - } + get_features.return_value = features api = FeatureApi() raw_get = unwrap(FeatureApi.get) result = raw_get(api, "tenant_123") - assert result == {"features": {"feature_a": True}} + expected = features.model_dump() + expected.pop("vector_space") + assert result == expected get_features.assert_called_once_with("tenant_123", exclude_vector_space=True) @@ -35,7 +39,7 @@ class TestFeatureVectorSpaceApi: from controllers.console.feature import FeatureVectorSpaceApi get_vector_space = mocker.patch("controllers.console.feature.FeatureService.get_vector_space") - get_vector_space.return_value.model_dump.return_value = {"size": 5120, "limit": 20480} + get_vector_space.return_value = LimitationModel(size=5120, limit=20480) api = FeatureVectorSpaceApi() @@ -85,22 +89,23 @@ class TestSystemFeatureApi: from controllers.console.feature import SystemFeatureApi - fake_user = mocker.Mock() - fake_user.is_authenticated = True - - mocker.patch( - "controllers.console.feature.current_user", - fake_user, + account = make_account() + current_account = mocker.patch( + "controllers.console.feature.current_account_with_tenant_optional", + return_value=(account, "tenant-123"), + ) + system_features = SystemFeatureModel(is_allow_register=True) + get_system_features = mocker.patch( + "controllers.console.feature.FeatureService.get_system_features", + return_value=system_features, ) - - mocker.patch( - "controllers.console.feature.FeatureService.get_system_features" - ).return_value.model_dump.return_value = {"features": {"sys_feature": True}} api = SystemFeatureApi() result = api.get() - assert result == {"features": {"sys_feature": True}} + assert result == system_features.model_dump() + current_account.assert_called_once_with() + get_system_features.assert_called_once_with(is_authenticated=True) def test_get_system_features_unauthenticated(self, mocker: MockerFixture): """ @@ -109,19 +114,19 @@ class TestSystemFeatureApi: from controllers.console.feature import SystemFeatureApi - fake_user = mocker.Mock() - type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized()) - - mocker.patch( - "controllers.console.feature.current_user", - fake_user, + current_account = mocker.patch( + "controllers.console.feature.current_account_with_tenant_optional", + return_value=(None, None), + ) + system_features = SystemFeatureModel(is_allow_register=False) + get_system_features = mocker.patch( + "controllers.console.feature.FeatureService.get_system_features", + return_value=system_features, ) - - mocker.patch( - "controllers.console.feature.FeatureService.get_system_features" - ).return_value.model_dump.return_value = {"features": {"sys_feature": False}} api = SystemFeatureApi() result = api.get() - assert result == {"features": {"sys_feature": False}} + assert result == system_features.model_dump() + current_account.assert_called_once_with() + get_system_features.assert_called_once_with(is_authenticated=False) diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index fb2ef55fe80..937505dab28 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -1,3 +1,4 @@ +from typing import override from unittest.mock import MagicMock, patch import pytest @@ -36,6 +37,7 @@ class MockUser(UserMixin): self.id = user_id self.current_tenant_id = "tenant123" + @override def get_id(self) -> str: return self.id @@ -210,6 +212,7 @@ class TestModelValidationInjection: Handler().post() assert exc_info.value.code == 422 + assert exc_info.value.description is not None assert "count" in exc_info.value.description diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 4809cc0e8a2..f47de0c8d50 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -9,21 +9,21 @@ Strategy: - ``HitTestingApi.post`` is decorated with ``@cloud_edition_billing_rate_limit_check`` which preserves ``__wrapped__``. We call ``post.__wrapped__(self, ...)`` to skip the billing decorator and test the business logic directly. -- Base-class methods (``get_and_validate_dataset``, ``perform_hit_testing``) read - ``current_user`` from ``controllers.console.datasets.hit_testing_base``, so we - patch it there. +- ``validate_dataset_token`` installs the tenant owner account into Flask-Login's + request context before calling the handler, so direct method-call tests install + the same concrete account on ``g._login_user``. """ import uuid -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest -from flask import Flask +from flask import Flask, g from werkzeug.exceptions import Forbidden, NotFound import services from controllers.service_api.dataset.hit_testing import HitTestingApi, HitTestingPayload -from models.account import Account +from models.account import Account, Tenant, TenantAccountRole from models.dataset import Dataset from services.entities.knowledge_entities.knowledge_entities import RetrievalModel @@ -131,13 +131,21 @@ class TestHitTestingApiPost: def _dataset(dataset_id: str, tenant_id: str) -> Dataset: return Dataset(id=dataset_id, tenant_id=tenant_id, name="Dataset", created_by="account-1") + @staticmethod + def _account(tenant_id: str) -> Account: + account = Account(name="Service API", email="service-api@example.com") + account.id = "account-1" + tenant = Tenant(name="Tenant") + tenant.id = tenant_id + account._current_tenant = tenant + account.role = TenantAccountRole.OWNER + return account + @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") - @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) def test_post_success( self, - mock_current_user, mock_dataset_svc, mock_hit_svc, mock_ns, @@ -148,6 +156,7 @@ class TestHitTestingApiPost: tenant_id = str(uuid.uuid4()) mock_dataset = self._dataset(dataset_id, tenant_id) + account = self._account(tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None @@ -158,6 +167,8 @@ class TestHitTestingApiPost: mock_ns.payload = {"query": "test query"} with app.test_request_context(): + # TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack + g._login_user = account api = HitTestingApi() # Skip billing decorator via __wrapped__ response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) @@ -168,10 +179,8 @@ class TestHitTestingApiPost: @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") - @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) def test_post_with_retrieval_model( self, - mock_current_user, mock_dataset_svc, mock_hit_svc, mock_ns, @@ -182,6 +191,7 @@ class TestHitTestingApiPost: tenant_id = str(uuid.uuid4()) mock_dataset = self._dataset(dataset_id, tenant_id) + account = self._account(tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None @@ -204,6 +214,8 @@ class TestHitTestingApiPost: } with app.test_request_context(): + # TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack + g._login_user = account api = HitTestingApi() response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) @@ -218,10 +230,8 @@ class TestHitTestingApiPost: @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") - @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) def test_post_preserves_retrieval_model_metadata_filtering_conditions( self, - mock_current_user, mock_dataset_svc, mock_hit_svc, mock_ns, @@ -232,6 +242,7 @@ class TestHitTestingApiPost: tenant_id = str(uuid.uuid4()) mock_dataset = self._dataset(dataset_id, tenant_id) + account = self._account(tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None @@ -260,6 +271,8 @@ class TestHitTestingApiPost: } with app.test_request_context(): + # TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack + g._login_user = account api = HitTestingApi() HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) @@ -270,10 +283,8 @@ class TestHitTestingApiPost: @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.HitTestingService") @patch("controllers.console.datasets.hit_testing_base.DatasetService") - @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) def test_post_prepares_nullable_list_fields( self, - mock_current_user, mock_dataset_svc, mock_hit_svc, mock_ns, @@ -284,6 +295,7 @@ class TestHitTestingApiPost: tenant_id = str(uuid.uuid4()) mock_dataset = self._dataset(dataset_id, tenant_id) + account = self._account(tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None @@ -297,6 +309,8 @@ class TestHitTestingApiPost: mock_ns.payload = {"query": "legacy query"} with app.test_request_context(): + # TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack + g._login_user = account api = HitTestingApi() response = HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) @@ -312,10 +326,8 @@ class TestHitTestingApiPost: @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") - @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) def test_post_dataset_not_found( self, - mock_current_user, mock_dataset_svc, mock_ns, app: Flask, @@ -323,21 +335,22 @@ class TestHitTestingApiPost: """Test hit testing with non-existent dataset.""" dataset_id = str(uuid.uuid4()) tenant_id = str(uuid.uuid4()) + account = self._account(tenant_id) mock_dataset_svc.get_dataset.return_value = None mock_ns.payload = {"query": "test query"} with app.test_request_context(): + # TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack + g._login_user = account api = HitTestingApi() with pytest.raises(NotFound): HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") - @patch("controllers.console.datasets.hit_testing_base.current_user", new_callable=lambda: Mock(spec=Account)) def test_post_no_dataset_permission( self, - mock_current_user, mock_dataset_svc, mock_ns, app: Flask, @@ -347,6 +360,7 @@ class TestHitTestingApiPost: tenant_id = str(uuid.uuid4()) mock_dataset = self._dataset(dataset_id, tenant_id) + account = self._account(tenant_id) mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.side_effect = services.errors.account.NoPermissionError( @@ -355,6 +369,8 @@ class TestHitTestingApiPost: mock_ns.payload = {"query": "test query"} with app.test_request_context(): + # TODO: the service APIs are NOT migrated yet, so we have to do the very dirty hack + g._login_user = account api = HitTestingApi() with pytest.raises(Forbidden): HitTestingApi.post.__wrapped__(api, tenant_id, dataset_id) diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index 2bf22128448..8b32e448d64 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -1,15 +1,15 @@ -from types import SimpleNamespace +from typing import cast from unittest.mock import MagicMock import pytest from flask import Flask, Response, g -from flask_login import UserMixin from pytest_mock import MockerFixture +from werkzeug.exceptions import Unauthorized import libs.login as login_module from extensions.ext_login import DifyLoginManager from libs.login import current_user -from models.account import Account +from models.account import Account, Tenant @pytest.fixture @@ -23,7 +23,7 @@ def protected_view(): return _protected_view -class MockUser(UserMixin): +class MockUser: """Mock user class for testing.""" def __init__(self, id: str, is_authenticated: bool = True): @@ -35,6 +35,22 @@ class MockUser(UserMixin): return self._is_authenticated +class LoginManagerStub: + def __init__(self, unauthorized_response: Response) -> None: + self._unauthorized_response = unauthorized_response + + def unauthorized(self) -> Response: + return self._unauthorized_response + + +def _login_manager(app: Flask) -> DifyLoginManager: + return cast(DifyLoginManager, app.__dict__["login_manager"]) + + +def _unauthorized_mock(app: Flask) -> MagicMock: + return cast(MagicMock, _login_manager(app).unauthorized) + + @pytest.fixture def login_app(mocker: MockerFixture) -> Flask: app = Flask(__name__) @@ -95,7 +111,7 @@ class TestLoginRequired: assert result == "Protected content" resolve_user.assert_called_once_with() - login_app.login_manager.unauthorized.assert_not_called() + _unauthorized_mock(login_app).assert_not_called() @pytest.mark.parametrize( ("resolved_user", "description"), @@ -120,11 +136,11 @@ class TestLoginRequired: with login_app.test_request_context(): result = protected_view() - assert result is login_app.login_manager.unauthorized.return_value, description + assert result is _unauthorized_mock(login_app).return_value, description assert isinstance(result, Response) assert result.status_code == 401 resolve_user.assert_called_once_with() - login_app.login_manager.unauthorized.assert_called_once_with() + _unauthorized_mock(login_app).assert_called_once_with() csrf_check.assert_not_called() def test_unauthorized_access_propagates_response_object( @@ -138,9 +154,7 @@ class TestLoginRequired: """Test that unauthorized responses are propagated as Flask Response objects.""" resolve_user = resolve_current_user(None) response = Response("Unauthorized", status=401, content_type="application/json") - mocker.patch.object( - login_module, "_get_login_manager", return_value=SimpleNamespace(unauthorized=lambda: response) - ) + mocker.patch.object(login_module, "_get_login_manager", return_value=LoginManagerStub(response)) with login_app.test_request_context(): result = protected_view() @@ -177,7 +191,7 @@ class TestLoginRequired: assert result == "Protected content" resolve_user.assert_not_called() csrf_check.assert_not_called() - login_app.login_manager.unauthorized.assert_not_called() + _unauthorized_mock(login_app).assert_not_called() class TestGetUser: @@ -191,6 +205,7 @@ class TestGetUser: g._login_user = mock_user user = login_module._get_user() assert user == mock_user + assert user is not None assert user.id == "test_user" def test_get_user_loads_user_if_not_in_g(self, login_app: Flask, mocker: MockerFixture): @@ -201,7 +216,7 @@ class TestGetUser: g._login_user = mock_user load_user = mocker.patch.object( - login_app.login_manager, + _login_manager(login_app), "load_user_from_request_context", side_effect=load_user_from_request_context, ) @@ -244,7 +259,9 @@ class TestCurrentAccountWithTenant: def test_returns_account_and_tenant_id(self, mocker: MockerFixture): account = Account(name="Test User", email="test@example.com") - account._current_tenant = SimpleNamespace(id="tenant-123") + tenant = Tenant(name="Test Tenant") + tenant.id = "tenant-123" + account._current_tenant = tenant current_user_proxy = mocker.Mock() current_user_proxy._get_current_object.return_value = account mocker.patch.object(login_module, "current_user", new=current_user_proxy) @@ -267,3 +284,58 @@ class TestCurrentAccountWithTenant: with pytest.raises(AssertionError, match="tenant information should be loaded"): login_module.current_account_with_tenant() + + +class TestCurrentAccountWithTenantOptional: + """Test cases for optional current account resolution.""" + + def test_returns_account_and_tenant_id_for_authenticated_account(self, mocker: MockerFixture) -> None: + account = Account(name="Test User", email="test@example.com") + tenant = Tenant(name="Test Tenant") + tenant.id = "tenant-123" + account._current_tenant = tenant + mocker.patch.object(login_module, "_resolve_current_user", return_value=account) + + user, tenant_id = login_module.current_account_with_tenant_optional() + + assert user is account + assert tenant_id == "tenant-123" + + def test_returns_none_pair_when_request_loader_raises_unauthorized(self, mocker: MockerFixture) -> None: + mocker.patch.object(login_module, "_resolve_current_user", side_effect=Unauthorized()) + + user, tenant_id = login_module.current_account_with_tenant_optional() + + assert user is None + assert tenant_id is None + + def test_returns_none_pair_when_resolved_user_is_not_account(self, mocker: MockerFixture) -> None: + mocker.patch.object(login_module, "_resolve_current_user", return_value=MockUser("end-user")) + + user, tenant_id = login_module.current_account_with_tenant_optional() + + assert user is None + assert tenant_id is None + + +class TestResolveTenantIdFallback: + """Test cases for tenant-only fallback helper.""" + + def test_returns_provided_tenant_id_without_current_user_lookup(self, mocker: MockerFixture) -> None: + current_account_with_tenant = mocker.patch.object(login_module, "current_account_with_tenant") + + tenant_id = login_module.resolve_tenant_id_fallback("tenant-123") + + assert tenant_id == "tenant-123" + current_account_with_tenant.assert_not_called() + + def test_falls_back_to_current_account_tenant(self, mocker: MockerFixture) -> None: + account = Account(name="Test User", email="test@example.com") + tenant = Tenant(name="Test Tenant") + tenant.id = "tenant-123" + account._current_tenant = tenant + mocker.patch.object(login_module, "current_account_with_tenant", return_value=(account, tenant.id)) + + tenant_id = login_module.resolve_tenant_id_fallback() + + assert tenant_id == "tenant-123" diff --git a/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_customized_retrieval.py b/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_customized_retrieval.py index 647a2f0bfc9..106b959a78b 100644 --- a/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_customized_retrieval.py +++ b/api/tests/unit_tests/services/rag_pipeline/pipeline_template/test_customized_retrieval.py @@ -5,10 +5,6 @@ from services.rag_pipeline.pipeline_template.pipeline_template_type import Pipel def test_get_pipeline_templates(mocker) -> None: - mocker.patch( - "services.rag_pipeline.pipeline_template.customized.customized_retrieval.current_account_with_tenant", - return_value=("account-id", "tenant-id"), - ) customized_template = SimpleNamespace( id="tpl-1", name="Custom Template", @@ -27,7 +23,7 @@ def test_get_pipeline_templates(mocker) -> None: ) retrieval = CustomizedPipelineTemplateRetrieval() - result = retrieval.get_pipeline_templates("en-US") + result = retrieval.get_pipeline_templates("en-US", "tenant-id") assert retrieval.get_type() == PipelineTemplateType.CUSTOMIZED assert result == { diff --git a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py index efb79aadde2..b255595047d 100644 --- a/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py +++ b/api/tests/unit_tests/services/rag_pipeline/test_rag_pipeline_service.py @@ -1,9 +1,14 @@ +import json import time +from datetime import datetime from types import SimpleNamespace import pytest from sqlalchemy.orm import sessionmaker +from models import Account, Tenant +from models.dataset import Dataset, Pipeline, PipelineCustomizedTemplate, PipelineRecommendedPlugin +from models.workflow import Workflow from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, PipelineTemplateInfoEntity from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -25,6 +30,89 @@ class MockRepo: pass +def _make_account(account_id: str = "u1", tenant_id: str = "t1") -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + tenant = Tenant(name="Test Tenant") + tenant.id = tenant_id + account._current_tenant = tenant + return account + + +def _make_pipeline( + *, + pipeline_id: str = "p1", + tenant_id: str = "t1", + workflow_id: str | None = None, + is_published: bool = False, +) -> Pipeline: + pipeline = Pipeline(tenant_id=tenant_id, name="Test Pipeline", description="test") + pipeline.id = pipeline_id + pipeline.workflow_id = workflow_id + pipeline.is_published = is_published + return pipeline + + +def _make_workflow( + *, + workflow_id: str = "wf-1", + tenant_id: str = "t1", + app_id: str = "p1", + graph: dict[str, object] | None = None, + features: dict[str, object] | None = None, + created_by: str = "u1", +) -> Workflow: + workflow = Workflow( + id=workflow_id, + tenant_id=tenant_id, + app_id=app_id, + type="workflow", + version="draft", + marked_name="", + marked_comment="", + graph=json.dumps(graph or {"nodes": []}), + features=json.dumps(features or {}), + created_by=created_by, + created_at=datetime(2024, 1, 1), + updated_by=None, + updated_at=datetime(2024, 1, 1), + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + return workflow + + +def _make_dataset(*, dataset_id: str = "d1", pipeline_id: str = "p1", tenant_id: str = "t1") -> Dataset: + dataset = Dataset( + id=dataset_id, + tenant_id=tenant_id, + name="Test Dataset", + created_by="u1", + ) + dataset.pipeline_id = pipeline_id + return dataset + + +def _make_customized_template() -> PipelineCustomizedTemplate: + return PipelineCustomizedTemplate( + tenant_id="t1", + name="old", + description="old", + chunk_structure="paragraph", + icon={}, + position=1, + yaml_content="", + install_count=0, + language="en-US", + created_by="u1", + ) + + +def _make_recommended_plugin(plugin_id: str) -> PipelineRecommendedPlugin: + return PipelineRecommendedPlugin(plugin_id=plugin_id, provider_name=plugin_id, type="tool", position=0, active=True) + + def test_get_pipeline_templates_fallbacks_to_builtin_for_non_english_empty_result(mocker) -> None: mocker.patch("services.rag_pipeline.rag_pipeline.dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE", "remote") @@ -74,7 +162,7 @@ def test_get_pipeline_template_detail_uses_expected_mode(mocker, template_type: def test_get_published_workflow_returns_none_when_pipeline_has_no_workflow_id(rag_pipeline_service) -> None: - pipeline = SimpleNamespace(workflow_id=None) + pipeline = _make_pipeline(workflow_id=None) result = rag_pipeline_service.get_published_workflow(pipeline) @@ -82,7 +170,7 @@ def test_get_published_workflow_returns_none_when_pipeline_has_no_workflow_id(ra def test_get_all_published_workflow_returns_empty_for_unpublished_pipeline(rag_pipeline_service) -> None: - pipeline = SimpleNamespace(workflow_id=None) + pipeline = _make_pipeline(workflow_id=None) session = SimpleNamespace() workflows, has_more = rag_pipeline_service.get_all_published_workflow( @@ -101,7 +189,7 @@ def test_get_all_published_workflow_returns_empty_for_unpublished_pipeline(rag_p def test_get_all_published_workflow_applies_limit_and_has_more(rag_pipeline_service) -> None: scalars_result = SimpleNamespace(all=lambda: ["wf1", "wf2", "wf3"]) session = SimpleNamespace(scalars=lambda stmt: scalars_result) - pipeline = SimpleNamespace(id="pipeline-1", workflow_id="wf-live") + pipeline = _make_pipeline(pipeline_id="pipeline-1", workflow_id="wf-live") workflows, has_more = rag_pipeline_service.get_all_published_workflow( session=session, @@ -133,8 +221,8 @@ def test_sync_draft_workflow_creates_new_when_none_exists(mocker, rag_pipeline_s mocker.patch("services.rag_pipeline.rag_pipeline.db.session.flush") mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") - pipeline = SimpleNamespace(tenant_id="t1", id="p1", workflow_id=None) - account = SimpleNamespace(id="u1") + pipeline = _make_pipeline(workflow_id=None) + account = _make_account() result = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, @@ -153,11 +241,11 @@ def test_sync_draft_workflow_creates_new_when_none_exists(mocker, rag_pipeline_s def test_sync_draft_workflow_raises_on_hash_mismatch(mocker, rag_pipeline_service) -> None: from services.errors.app import WorkflowHashNotEqualError - existing_wf = SimpleNamespace(unique_hash="hash-old") + existing_wf = _make_workflow(graph={"nodes": [{"id": "old"}]}) mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=existing_wf) - pipeline = SimpleNamespace(tenant_id="t1", id="p1") - account = SimpleNamespace(id="u1") + pipeline = _make_pipeline() + account = _make_account() with pytest.raises(WorkflowHashNotEqualError): rag_pipeline_service.sync_draft_workflow( @@ -184,8 +272,8 @@ def test_sync_draft_workflow_updates_existing(mocker, rag_pipeline_service) -> N mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=existing_wf) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") - pipeline = SimpleNamespace(tenant_id="t1", id="p1") - account = SimpleNamespace(id="u1") + pipeline = _make_pipeline() + account = _make_account() result = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, @@ -275,7 +363,7 @@ def test_get_rag_pipeline_paginate_workflow_runs_delegates(mocker, rag_pipeline_ repo_mock.get_paginated_workflow_runs.return_value = expected rag_pipeline_service._workflow_run_repo = repo_mock - pipeline = SimpleNamespace(tenant_id="t1", id="p1") + pipeline = _make_pipeline() result = rag_pipeline_service.get_rag_pipeline_paginate_workflow_runs(pipeline, {"limit": 10, "last_id": "abc"}) assert result is expected @@ -297,7 +385,7 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) - repo_mock.get_workflow_run_by_id.return_value = expected rag_pipeline_service._workflow_run_repo = repo_mock - pipeline = SimpleNamespace(tenant_id="t1", id="p1") + pipeline = _make_pipeline() result = rag_pipeline_service.get_rag_pipeline_workflow_run(pipeline, "run-1") assert result is expected @@ -310,14 +398,14 @@ def test_get_rag_pipeline_workflow_run_delegates(mocker, rag_pipeline_service) - def test_is_workflow_exist_returns_true_when_draft_exists(mocker, rag_pipeline_service) -> None: mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=1) - pipeline = SimpleNamespace(tenant_id="t1", id="p1") + pipeline = _make_pipeline() assert rag_pipeline_service.is_workflow_exist(pipeline) is True def test_is_workflow_exist_returns_false_when_no_draft(mocker, rag_pipeline_service) -> None: mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=0) - pipeline = SimpleNamespace(tenant_id="t1", id="p1") + pipeline = _make_pipeline() assert rag_pipeline_service.is_workflow_exist(pipeline) is False @@ -635,17 +723,10 @@ def test_get_second_step_parameters_success(mocker, rag_pipeline_service) -> Non def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_service) -> None: - from models.dataset import Pipeline - # 1. Setup mocks - pipeline = mocker.Mock(spec=Pipeline) - pipeline.id = "p1" - pipeline.tenant_id = "t1" - pipeline.workflow_id = "wf-1" - pipeline.is_published = True + pipeline = _make_pipeline(workflow_id="wf-1", is_published=True) - workflow = mocker.Mock() - workflow.id = "wf-1" + workflow = _make_workflow(workflow_id="wf-1") # Mock db itself to avoid app context errors mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") @@ -656,8 +737,8 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi mock_db.session.scalar.side_effect = [None, 5] # Mock retrieve_dataset - dataset = mocker.Mock() - pipeline.retrieve_dataset.return_value = dataset + dataset = _make_dataset() + dataset.chunk_structure = "paragraph" # Mock RagPipelineDslService mock_dsl_service = mocker.Mock() @@ -665,16 +746,14 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi mocker.patch("services.rag_pipeline.rag_pipeline_dsl_service.RagPipelineDslService", return_value=mock_dsl_service) # Mock Session and commit - mocker.patch("services.rag_pipeline.rag_pipeline.Session", return_value=mocker.MagicMock()) + session_factory = mocker.patch("services.rag_pipeline.rag_pipeline.sessionmaker") + session_factory.return_value.begin.return_value.__enter__.return_value.scalar.return_value = dataset - # Mock current_user - mock_user = mocker.Mock() - mock_user.id = "user-123" - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", mock_user) + account = _make_account(account_id="user-123") # 2. Run test args = {"name": "New Template", "description": "Desc", "icon_info": {"icon": "star"}, "tags": ["tag1"]} - rag_pipeline_service.publish_customized_pipeline_template("p1", args) + rag_pipeline_service.publish_customized_pipeline_template("p1", args, account, "t1") # 3. Assertions # Verify a new template was added to session or similar? @@ -687,14 +766,10 @@ def test_publish_customized_pipeline_template_success(mocker, rag_pipeline_servi def test_get_datasource_plugins_success(mocker, rag_pipeline_service) -> None: - from models.dataset import Dataset, Pipeline - # 1. Setup mocks - dataset = mocker.Mock(spec=Dataset) - dataset.pipeline_id = "p1" + dataset = _make_dataset() - pipeline = mocker.Mock(spec=Pipeline) - pipeline.id = "p1" + pipeline = _make_pipeline() workflow = mocker.Mock() workflow.graph_dict = { @@ -835,15 +910,10 @@ def test_set_datasource_variables_success(mocker, rag_pipeline_service) -> None: def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None: - from models.dataset import Pipeline - from models.workflow import Workflow - # 1. Setup mocks - pipeline = mocker.Mock(spec=Pipeline) - pipeline.id = "p1" - pipeline.tenant_id = "t1" + pipeline = _make_pipeline() - workflow = mocker.Mock(spec=Workflow) + workflow = _make_workflow() mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.session.scalar.return_value = workflow @@ -856,16 +926,10 @@ def test_get_draft_workflow_success(mocker, rag_pipeline_service) -> None: def test_get_published_workflow_success(mocker, rag_pipeline_service) -> None: - from models.dataset import Pipeline - from models.workflow import Workflow - # 1. Setup mocks - pipeline = mocker.Mock(spec=Pipeline) - pipeline.id = "p1" - pipeline.tenant_id = "t1" - pipeline.workflow_id = "wf-pub" + pipeline = _make_pipeline(workflow_id="wf-pub") - workflow = mocker.Mock(spec=Workflow) + workflow = _make_workflow(workflow_id="wf-pub") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.session.scalar.return_value = workflow @@ -896,8 +960,8 @@ def test_get_default_block_config_success(rag_pipeline_service) -> None: def test_publish_workflow_raises_when_draft_workflow_missing(mocker, rag_pipeline_service) -> None: session = mocker.Mock() session.scalar.return_value = None - pipeline = SimpleNamespace(id="p1", tenant_id="t1") - account = SimpleNamespace(id="u1") + pipeline = _make_pipeline() + account = _make_account() with pytest.raises(ValueError, match="No valid workflow found"): rag_pipeline_service.publish_workflow(session=session, pipeline=pipeline, account=account) @@ -929,8 +993,8 @@ def test_get_default_block_config_injects_http_request_filter(mocker, rag_pipeli def test_run_draft_workflow_node_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: - pipeline = SimpleNamespace(id="p1", tenant_id="t1") - account = SimpleNamespace(id="u1") + pipeline = _make_pipeline() + account = _make_account() mocker.patch.object(rag_pipeline_service, "get_draft_workflow", return_value=None) with pytest.raises(ValueError, match="Workflow not initialized"): @@ -939,8 +1003,8 @@ def test_run_draft_workflow_node_raises_when_workflow_missing(mocker, rag_pipeli def test_run_draft_workflow_node_saves_execution_and_variables(mocker, rag_pipeline_service) -> None: mocker.patch("services.rag_pipeline.rag_pipeline.db", mocker.Mock(engine=mocker.Mock())) - pipeline = SimpleNamespace(id="p1", tenant_id="t1") - account = SimpleNamespace(id="u1") + pipeline = _make_pipeline() + account = _make_account() draft_workflow = mocker.Mock(id="wf-1") draft_workflow.get_node_config_by_id.return_value = {"id": "node-1"} draft_workflow.get_enclosing_node_type_and_id.return_value = ("loop", "enclosing-node") @@ -1163,11 +1227,11 @@ def test_get_second_step_parameters_handles_string_and_list_variable_references( def test_get_rag_pipeline_workflow_run_node_executions_empty_when_run_missing(mocker, rag_pipeline_service) -> None: - pipeline = SimpleNamespace(id="p1", tenant_id="t1") + pipeline = _make_pipeline() mocker.patch.object(rag_pipeline_service, "get_rag_pipeline_workflow_run", return_value=None) result = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( - pipeline=pipeline, run_id="run-1", user=SimpleNamespace(id="u1") + pipeline=pipeline, run_id="run-1", user=_make_account() ) assert result == [] @@ -1175,14 +1239,14 @@ def test_get_rag_pipeline_workflow_run_node_executions_empty_when_run_missing(mo def test_get_rag_pipeline_workflow_run_node_executions_returns_sorted_executions(mocker, rag_pipeline_service) -> None: mocker.patch("services.rag_pipeline.rag_pipeline.db", mocker.Mock(engine=mocker.Mock())) - pipeline = SimpleNamespace(id="p1", tenant_id="t1") + pipeline = _make_pipeline() mocker.patch.object(rag_pipeline_service, "get_rag_pipeline_workflow_run", return_value=SimpleNamespace(id="run-1")) repo = mocker.Mock() repo.get_db_models_by_workflow_run.return_value = ["n1", "n2"] mocker.patch("services.rag_pipeline.rag_pipeline.SQLAlchemyWorkflowNodeExecutionRepository", return_value=repo) result = rag_pipeline_service.get_rag_pipeline_workflow_run_node_executions( - pipeline=pipeline, run_id="run-1", user=SimpleNamespace(id="u1") + pipeline=pipeline, run_id="run-1", user=_make_account() ) assert result == ["n1", "n2"] @@ -1192,7 +1256,7 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.session.scalars.return_value.all.return_value = [] - result = rag_pipeline_service.get_recommended_plugins("all") + result = rag_pipeline_service.get_recommended_plugins("all", _make_account(), "t1") assert result == { "installed_recommended_plugins": [], @@ -1201,11 +1265,10 @@ def test_get_recommended_plugins_returns_empty_when_no_active_plugins(mocker, ra def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_pipeline_service) -> None: - plugin_a = SimpleNamespace(plugin_id="plugin-a") - plugin_b = SimpleNamespace(plugin_id="plugin-b") + plugin_a = _make_recommended_plugin("plugin-a") + plugin_b = _make_recommended_plugin("plugin-b") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.session.scalars.return_value.all.return_value = [plugin_a, plugin_b] - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch( "services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[SimpleNamespace(plugin_id="plugin-a", to_dict=lambda: {"plugin_id": "plugin-a"})], @@ -1215,7 +1278,7 @@ def test_get_recommended_plugins_returns_installed_and_uninstalled(mocker, rag_p return_value=[{"plugin_id": "plugin-b", "name": "Plugin B"}], ) - result = rag_pipeline_service.get_recommended_plugins("custom") + result = rag_pipeline_service.get_recommended_plugins("custom", _make_account(), "t1") assert result["installed_recommended_plugins"] == [{"plugin_id": "plugin-a"}] assert result["uninstalled_recommended_plugins"] == [{"plugin_id": "plugin-b", "name": "Plugin B"}] @@ -1229,8 +1292,8 @@ def test_get_node_last_run_delegates_to_repository(mocker, rag_pipeline_service) "services.rag_pipeline.rag_pipeline.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository", return_value=repo, ) - pipeline = SimpleNamespace(id="p1", tenant_id="t1") - workflow = SimpleNamespace(id="wf1") + pipeline = _make_pipeline() + workflow = _make_workflow(workflow_id="wf1") result = rag_pipeline_service.get_node_last_run(pipeline, workflow, "node-1") @@ -1572,15 +1635,17 @@ def test_publish_customized_pipeline_template_raises_for_missing_pipeline(mocker mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=None) with pytest.raises(ValueError, match="Pipeline not found"): - rag_pipeline_service.publish_customized_pipeline_template("p1", {}) + rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1") def test_publish_customized_pipeline_template_raises_for_missing_workflow_id(mocker, rag_pipeline_service) -> None: - pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id=None) + pipeline = _make_pipeline(workflow_id=None) mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", return_value=pipeline) with pytest.raises(ValueError, match="Pipeline workflow not found"): - rag_pipeline_service.publish_customized_pipeline_template("p1", {"name": "template-name"}) + rag_pipeline_service.publish_customized_pipeline_template( + "p1", {"name": "template-name"}, _make_account(), "t1" + ) def test_get_pipeline_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: @@ -1630,13 +1695,12 @@ def test_get_pipeline_templates_builtin_en_us_no_fallback(mocker) -> None: def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> None: - template = SimpleNamespace(name="old", description="old", icon={}, updated_by=None) + template = _make_customized_template() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", return_value=template) commit = mocker.patch("services.rag_pipeline.rag_pipeline.db.session.commit") - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) info = PipelineTemplateInfoEntity(name="", description="updated", icon_info=IconInfo(icon="i")) - result = RagPipelineService.update_customized_pipeline_template("tpl-1", info) + result = RagPipelineService.update_customized_pipeline_template("tpl-1", info, _make_account(), "t1") assert result.description == "updated" commit.assert_called_once() @@ -1644,7 +1708,7 @@ def test_update_customized_pipeline_template_commits_when_name_empty(mocker) -> def test_get_all_published_workflow_without_filters_has_no_more(rag_pipeline_service) -> None: session = SimpleNamespace(scalars=lambda stmt: SimpleNamespace(all=lambda: ["wf1"])) - pipeline = SimpleNamespace(id="p1", workflow_id="wf-live") + pipeline = _make_pipeline(workflow_id="wf-live") workflows, has_more = rag_pipeline_service.get_all_published_workflow( session=session, @@ -1856,38 +1920,34 @@ def test_run_free_workflow_node_delegates_to_handle_result(mocker, rag_pipeline_ def test_publish_customized_pipeline_template_raises_when_workflow_missing(mocker, rag_pipeline_service) -> None: - pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") + pipeline = _make_pipeline(workflow_id="wf-1") mocker.patch("services.rag_pipeline.rag_pipeline.db.session.get", side_effect=[pipeline, None]) with pytest.raises(ValueError, match="Workflow not found"): - rag_pipeline_service.publish_customized_pipeline_template("p1", {}) + rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1") def test_publish_customized_pipeline_template_raises_when_dataset_missing(mocker, rag_pipeline_service) -> None: - pipeline = SimpleNamespace(id="p1", tenant_id="t1", workflow_id="wf-1") - workflow = SimpleNamespace(id="wf-1") + pipeline = _make_pipeline(workflow_id="wf-1") + workflow = _make_workflow(workflow_id="wf-1") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.engine = mocker.Mock() mock_db.session.get.side_effect = [pipeline, workflow] - session_ctx = mocker.MagicMock() - session_ctx.__enter__.return_value = SimpleNamespace() - session_ctx.__exit__.return_value = False - mocker.patch("services.rag_pipeline.rag_pipeline.Session", return_value=session_ctx) - pipeline.retrieve_dataset = lambda session: None + session_factory = mocker.patch("services.rag_pipeline.rag_pipeline.sessionmaker") + session_factory.return_value.begin.return_value.__enter__.return_value.scalar.return_value = None with pytest.raises(ValueError, match="Dataset not found"): - rag_pipeline_service.publish_customized_pipeline_template("p1", {}) + rag_pipeline_service.publish_customized_pipeline_template("p1", {}, _make_account(), "t1") def test_get_recommended_plugins_skips_manifest_when_missing(mocker, rag_pipeline_service) -> None: - plugin = SimpleNamespace(plugin_id="plugin-a") + plugin = _make_recommended_plugin("plugin-a") mock_db = mocker.patch("services.rag_pipeline.rag_pipeline.db") mock_db.session.scalars.return_value.all.return_value = [plugin] - mocker.patch("services.rag_pipeline.rag_pipeline.current_user", SimpleNamespace(id="u1", current_tenant_id="t1")) mocker.patch("services.rag_pipeline.rag_pipeline.BuiltinToolManageService.list_builtin_tools", return_value=[]) mocker.patch("services.rag_pipeline.rag_pipeline.marketplace.batch_fetch_plugin_by_ids", return_value=[]) - result = rag_pipeline_service.get_recommended_plugins("all") + result = rag_pipeline_service.get_recommended_plugins("all", _make_account(), "t1") assert result["installed_recommended_plugins"] == [] assert result["uninstalled_recommended_plugins"] == [] @@ -1918,8 +1978,8 @@ def test_retry_error_document_raises_when_workflow_missing(mocker, rag_pipeline_ def test_get_datasource_plugins_returns_empty_for_non_datasource_nodes(mocker, rag_pipeline_service) -> None: - dataset = SimpleNamespace(pipeline_id="p1") - pipeline = SimpleNamespace(id="p1", tenant_id="t1") + dataset = _make_dataset() + pipeline = _make_pipeline() workflow = SimpleNamespace( graph_dict={"nodes": [{"id": "n1", "data": {"type": "start"}}]}, rag_pipeline_variables=[] ) @@ -2079,8 +2139,8 @@ def test_set_datasource_variables_raises_when_workflow_missing(mocker, rag_pipel def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published(mocker, rag_pipeline_service) -> None: - dataset = SimpleNamespace(pipeline_id="p1") - pipeline = SimpleNamespace(id="p1", tenant_id="t1") + dataset = _make_dataset() + pipeline = _make_pipeline() workflow = SimpleNamespace( graph_dict={"nodes": [{"id": "n1", "data": {"type": "datasource", "datasource_parameters": {}}}]}, rag_pipeline_variables=[{"variable": "v1", "belong_to_node_id": "shared"}], @@ -2097,8 +2157,8 @@ def test_get_datasource_plugins_handles_empty_datasource_data_and_non_published( def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag_pipeline_service) -> None: - dataset = SimpleNamespace(pipeline_id="p1") - pipeline = SimpleNamespace(id="p1", tenant_id="t1") + dataset = _make_dataset() + pipeline = _make_pipeline() workflow = SimpleNamespace( graph_dict={ "nodes": [ @@ -2139,8 +2199,8 @@ def test_get_datasource_plugins_extracts_user_inputs_and_credentials(mocker, rag def test_get_pipeline_returns_pipeline_when_found(mocker, rag_pipeline_service) -> None: - dataset = SimpleNamespace(pipeline_id="p1") - pipeline = SimpleNamespace(id="p1") + dataset = _make_dataset() + pipeline = _make_pipeline() mocker.patch("services.rag_pipeline.rag_pipeline.db.session.scalar", side_effect=[dataset, pipeline]) result = rag_pipeline_service.get_pipeline("t1", "d1") diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index fc3a2fc416d..36ea1fac1a4 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -1,65 +1,62 @@ from pathlib import Path -from unittest.mock import Mock, create_autospec, patch +from typing import cast +from unittest.mock import Mock import pytest -from models.account import Account +from models import Account, Tenant from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService +def _make_account(account_id: str = "user-456", tenant_id: str = "tenant-123") -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + tenant = Tenant(name="Test Tenant") + tenant.id = tenant_id + account._current_tenant = tenant + return account + + class TestMetadataBugCompleteValidation: """Complete test suite to verify the metadata nullable bug and its fix.""" - def test_1_pydantic_layer_validation(self): + def test_1_pydantic_layer_validation(self) -> None: """Test Layer 1: Pydantic model validation correctly rejects None values.""" # Pydantic should reject None values for required fields with pytest.raises((ValueError, TypeError)): - MetadataArgs(type=None, name=None) + MetadataArgs(type=None, name=None) # pyrefly: ignore[bad-argument-type] with pytest.raises((ValueError, TypeError)): - MetadataArgs(type="string", name=None) + MetadataArgs(type="string", name=None) # pyrefly: ignore[bad-argument-type] with pytest.raises((ValueError, TypeError)): - MetadataArgs(type=None, name="test") + MetadataArgs(type=None, name="test") # pyrefly: ignore[bad-argument-type] # Valid values should work valid_args = MetadataArgs(type="string", name="test_name") assert valid_args.type == "string" assert valid_args.name == "test_name" - def test_2_business_logic_layer_crashes_on_none(self): + def test_2_business_logic_layer_crashes_on_none(self) -> None: """Test Layer 2: Business logic crashes when None values slip through.""" # Create mock that bypasses Pydantic validation mock_metadata_args = Mock() mock_metadata_args.name = None mock_metadata_args.type = "string" - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" - - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - # Should crash with TypeError - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args) + account = _make_account() + # Should crash with TypeError + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") # Test update method as well - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + account = _make_account() + none_name = cast(str, None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123") - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.update_metadata_name("dataset-123", "metadata-456", None) - - def test_3_database_constraints_verification(self): + def test_3_database_constraints_verification(self) -> None: """Test Layer 3: Verify database model has nullable=False constraints.""" from sqlalchemy import inspect @@ -75,7 +72,7 @@ class TestMetadataBugCompleteValidation: assert type_column.nullable is False, "type column should be nullable=False" assert name_column.nullable is False, "name column should be nullable=False" - def test_4_fixed_api_layer_rejects_null(self): + def test_4_fixed_api_layer_rejects_null(self) -> None: """Test Layer 4: Fixed API configuration properly rejects null values using Pydantic.""" with pytest.raises((ValueError, TypeError)): MetadataArgs.model_validate({"type": None, "name": None}) @@ -86,30 +83,23 @@ class TestMetadataBugCompleteValidation: with pytest.raises((ValueError, TypeError)): MetadataArgs.model_validate({"type": None, "name": "test"}) - def test_5_fixed_api_accepts_valid_values(self): + def test_5_fixed_api_accepts_valid_values(self) -> None: """Test that fixed API still accepts valid non-null values.""" args = MetadataArgs.model_validate({"type": "string", "name": "valid_name"}) assert args.type == "string" assert args.name == "valid_name" - def test_6_simulated_buggy_behavior(self): + def test_6_simulated_buggy_behavior(self) -> None: """Test simulating the original buggy behavior by bypassing Pydantic validation.""" mock_metadata_args = Mock() mock_metadata_args.name = None mock_metadata_args.type = None - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + account = _make_account() + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args) - - def test_7_end_to_end_validation_layers(self): + def test_7_end_to_end_validation_layers(self) -> None: """Test all validation layers work together correctly.""" # Layer 1: API should reject null at parameter level (with fix) # Layer 2: Pydantic should reject null at model level @@ -128,7 +118,7 @@ class TestMetadataBugCompleteValidation: assert len(metadata_args.name) <= 255 # This should not crash assert len(metadata_args.type) > 0 # This should not crash - def test_8_verify_specific_fix_locations(self): + def test_8_verify_specific_fix_locations(self) -> None: """Verify that the specific locations mentioned in bug report are fixed.""" # Read the actual files to verify fixes import os @@ -152,7 +142,7 @@ class TestMetadataBugCompleteValidation: class TestMetadataValidationSummary: """Summary tests that demonstrate the complete validation architecture.""" - def test_validation_layer_architecture(self): + def test_validation_layer_architecture(self) -> None: """Document and test the 4-layer validation architecture.""" # Layer 1: API Parameter Validation (Flask-RESTful reqparse) # - Role: First line of defense, validates HTTP request parameters diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index f43f394489a..27570a86f1a 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -1,56 +1,53 @@ -from unittest.mock import Mock, create_autospec, patch +from typing import cast +from unittest.mock import Mock import pytest -from models.account import Account +from models import Account, Tenant from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService +def _make_account(account_id: str = "user-456", tenant_id: str = "tenant-123") -> Account: + account = Account(name="Test User", email=f"{account_id}@example.com") + account.id = account_id + tenant = Tenant(name="Test Tenant") + tenant.id = tenant_id + account._current_tenant = tenant + return account + + class TestMetadataNullableBug: """Test case to reproduce the metadata nullable validation bug.""" - def test_metadata_args_with_none_values_should_fail(self): + def test_metadata_args_with_none_values_should_fail(self) -> None: """Test that MetadataArgs validation should reject None values.""" # This test demonstrates the expected behavior - should fail validation with pytest.raises((ValueError, TypeError)): # This should fail because Pydantic expects non-None values - MetadataArgs(type=None, name=None) + MetadataArgs(type=None, name=None) # pyrefly: ignore[bad-argument-type] - def test_metadata_service_create_with_none_name_crashes(self): + def test_metadata_service_create_with_none_name_crashes(self) -> None: """Test that MetadataService.create_metadata crashes when name is None.""" # Mock the MetadataArgs to bypass Pydantic validation mock_metadata_args = Mock() mock_metadata_args.name = None # This will cause len() to crash mock_metadata_args.type = "string" - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + account = _make_account() + # This should crash with TypeError when calling len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.create_metadata("dataset-123", mock_metadata_args, account, "tenant-123") - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - # This should crash with TypeError when calling len(None) - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.create_metadata("dataset-123", mock_metadata_args) - - def test_metadata_service_update_with_none_name_crashes(self): + def test_metadata_service_update_with_none_name_crashes(self) -> None: """Test that MetadataService.update_metadata_name crashes when name is None.""" - mock_user = create_autospec(Account, instance=True) - mock_user.current_tenant_id = "tenant-123" - mock_user.id = "user-456" + account = _make_account() + none_name = cast(str, None) + # This should crash with TypeError when calling len(None) + with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): + MetadataService.update_metadata_name("dataset-123", "metadata-456", none_name, account, "tenant-123") - with patch( - "services.metadata_service.current_account_with_tenant", - return_value=(mock_user, mock_user.current_tenant_id), - ): - # This should crash with TypeError when calling len(None) - with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): - MetadataService.update_metadata_name("dataset-123", "metadata-456", None) - - def test_api_layer_now_uses_pydantic_validation(self): + def test_api_layer_now_uses_pydantic_validation(self) -> None: """Verify that API layer relies on Pydantic validation instead of reqparse.""" invalid_payload = {"type": None, "name": None} with pytest.raises((ValueError, TypeError)):