From ecd830083a1a395fd6150cf6db81b3c3d9b20372 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 8 May 2026 10:06:25 +0900 Subject: [PATCH] test: add type to test (#35871) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../controllers/console/auth/test_oauth.py | 18 +-- .../console/auth/test_password_reset.py | 10 +- .../rag_pipeline/test_rag_pipeline.py | 18 +-- .../rag_pipeline/test_rag_pipeline_import.py | 14 +- .../test_rag_pipeline_workflow.py | 2 +- .../console/datasets/test_data_source.py | 34 ++--- .../console/explore/test_conversation.py | 20 +-- .../workspace/test_trigger_providers.py | 56 ++++---- .../service_api/dataset/test_dataset.py | 54 ++++---- .../controllers/web/test_wraps.py | 8 +- .../rag/pipeline/test_queue_integration.py | 22 ++-- .../services/test_app_generate_service.py | 3 +- .../services/test_messages_clean_service.py | 2 +- .../services/test_saved_message_service.py | 27 ++-- .../services/test_web_conversation_service.py | 4 +- .../console/auth/test_account_activation.py | 12 +- .../console/auth/test_email_verification.py | 12 +- .../console/auth/test_login_logout.py | 4 +- .../console/auth/test_token_refresh.py | 8 +- .../rag_pipeline/test_datasource_auth.py | 67 +++++----- .../test_datasource_content_preview.py | 9 +- .../console/datasets/test_datasets.py | 123 +++++++++--------- .../datasets/test_datasets_document.py | 9 +- .../datasets/test_datasets_segments.py | 63 ++++----- .../console/datasets/test_external.py | 22 ++-- .../console/datasets/test_metadata.py | 2 +- .../console/explore/test_banner.py | 8 +- .../console/explore/test_message.py | 27 ++-- .../console/explore/test_recommended_app.py | 10 +- .../console/explore/test_saved_message.py | 3 +- .../controllers/console/explore/test_trial.py | 15 ++- .../controllers/console/tag/test_tags.py | 10 +- .../controllers/console/test_files.py | 4 +- .../console/test_workspace_account.py | 22 ++-- .../console/test_workspace_members.py | 2 +- .../console/workspace/test_accounts.py | 33 ++--- .../console/workspace/test_agent_providers.py | 13 +- .../console/workspace/test_endpoint.py | 49 +++---- .../console/workspace/test_members.py | 55 ++++---- .../console/workspace/test_model_providers.py | 39 +++--- .../console/workspace/test_models.py | 12 +- .../console/workspace/test_plugin.py | 101 +++++++------- .../console/workspace/test_workspace.py | 57 ++++---- .../controllers/service_api/app/test_app.py | 25 ++-- .../controllers/service_api/app/test_audio.py | 11 +- .../service_api/app/test_completion.py | 13 +- .../service_api/app/test_conversation.py | 17 +-- .../controllers/service_api/app/test_file.py | 7 +- .../service_api/app/test_hitl_service_api.py | 3 +- .../service_api/app/test_message.py | 17 +-- .../service_api/app/test_workflow.py | 27 ++-- .../service_api/app/test_workflow_events.py | 11 +- .../test_rag_pipeline_workflow.py | 21 +-- .../dataset/test_dataset_segment.py | 55 ++++---- .../service_api/dataset/test_metadata.py | 29 +++-- .../controllers/service_api/test_index.py | 5 +- .../controllers/service_api/test_wraps.py | 32 ++--- .../test_external_data_fetch.py | 6 +- 58 files changed, 706 insertions(+), 656 deletions(-) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 01d88d247c..55b6a919d8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -90,7 +90,7 @@ class TestOAuthLogin: mock_redirect, mock_get_providers, resource, - app, + app: Flask, mock_oauth_provider, invite_token, expected_token, @@ -165,7 +165,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" @@ -218,7 +218,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" @@ -262,7 +262,7 @@ class TestOAuthCallback: mock_tenant_service, mock_account_service, resource, - app, + app: Flask, oauth_setup, account_status, expected_redirect, @@ -301,7 +301,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -337,7 +337,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): """Defensive test for CLOSED account status handling in OAuth callback. @@ -466,7 +466,7 @@ class TestAccountGeneration: mock_register_service, mock_feature_service, mock_get_account, - app, + app: Flask, user_info, mock_account, allow_register, @@ -505,7 +505,7 @@ class TestAccountGeneration: mock_register_service, mock_feature_service, mock_get_account, - app, + app: Flask, ): user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") mock_feature_service.get_system_features.return_value.is_allow_register = True @@ -530,7 +530,7 @@ class TestAccountGeneration: mock_feature_service, mock_tenant_service, mock_get_account, - app, + app: Flask, user_info, mock_account, ): diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8d6b25b5b3..d017e8f2bd 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -47,7 +47,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - app, + app: Flask, mock_account, ): # Arrange @@ -105,7 +105,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - app, + app: Flask, mock_account, language_input, expected_language, @@ -154,7 +154,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - app, + app: Flask, ): """ Test successful verification code validation. @@ -201,7 +201,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} @@ -345,7 +345,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - app, + app: Flask, mock_account, ): """ diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 2752e6b34f..7aa4aff1cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -30,7 +30,7 @@ class TestPipelineTemplateListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = PipelineTemplateListApi() method = unwrap(api.get) @@ -54,7 +54,7 @@ class TestPipelineTemplateDetailApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -75,7 +75,7 @@ class TestPipelineTemplateDetailApi: assert status == 200 assert response == template - def test_get_returns_404_when_template_not_found(self, app): + def test_get_returns_404_when_template_not_found(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -94,7 +94,7 @@ class TestPipelineTemplateDetailApi: assert status == 404 assert "error" in response - def test_get_returns_404_for_customized_type_not_found(self, app): + def test_get_returns_404_for_customized_type_not_found(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -119,7 +119,7 @@ class TestCustomizedPipelineTemplateApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.patch) @@ -141,7 +141,7 @@ class TestCustomizedPipelineTemplateApi: update_mock.assert_called_once() assert response == 200 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.delete) @@ -156,7 +156,7 @@ class TestCustomizedPipelineTemplateApi: delete_mock.assert_called_once_with("tpl-1") assert response == 200 - def test_post_success(self, app, db_session_with_containers: Session): + def test_post_success(self, app: Flask, db_session_with_containers: Session): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) @@ -183,7 +183,7 @@ class TestCustomizedPipelineTemplateApi: assert status == 200 assert response == {"data": "yaml-data"} - def test_post_template_not_found(self, app): + def test_post_template_not_found(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) @@ -197,7 +197,7 @@ class TestPublishCustomizedPipelineTemplateApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = PublishCustomizedPipelineTemplateApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index f238ca13ee..44eb5c336c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -36,7 +36,7 @@ class TestRagPipelineImportApi: "name": "Test", } - def test_post_success_200(self, app): + def test_post_success_200(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -66,7 +66,7 @@ class TestRagPipelineImportApi: assert status == 200 assert response == {"status": "success"} - def test_post_failed_400(self, app): + def test_post_failed_400(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -96,7 +96,7 @@ class TestRagPipelineImportApi: assert status == 400 assert response == {"status": "failed"} - def test_post_pending_202(self, app): + def test_post_pending_202(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -132,7 +132,7 @@ class TestRagPipelineImportConfirmApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_confirm_success(self, app): + def test_confirm_success(self, app: Flask): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -160,7 +160,7 @@ class TestRagPipelineImportConfirmApi: assert status == 200 assert response == {"ok": True} - def test_confirm_failed(self, app): + def test_confirm_failed(self, app: Flask): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -194,7 +194,7 @@ class TestRagPipelineImportCheckDependenciesApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = RagPipelineImportCheckDependenciesApi() method = unwrap(api.get) @@ -223,7 +223,7 @@ class TestRagPipelineExportApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_with_include_secret(self, app): + def test_get_with_include_secret(self, app: Flask): api = RagPipelineExportApi() method = unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 1fdb3057b8..c17a83cad3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -391,7 +391,7 @@ class TestPublishedPipelineApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_publish_success(self, app, db_session_with_containers: Session): + def test_publish_success(self, app: Flask, db_session_with_containers: Session): from models.dataset import Pipeline api = PublishedRagPipelineApi() diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 50ad92afa1..b59009f7c4 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -55,7 +55,7 @@ class TestDataSourceApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceApi() method = unwrap(api.get) @@ -79,7 +79,7 @@ class TestDataSourceApi: assert status == 200 assert response["data"][0]["is_bound"] is True - def test_get_no_bindings(self, app, patch_tenant): + def test_get_no_bindings(self, app: Flask, patch_tenant): api = DataSourceApi() method = unwrap(api.get) @@ -95,7 +95,7 @@ class TestDataSourceApi: assert status == 200 assert response["data"] == [] - def test_patch_enable_binding(self, app, patch_tenant, mock_engine): + def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -116,7 +116,7 @@ class TestDataSourceApi: assert status == 200 assert binding.disabled is False - def test_patch_disable_binding(self, app, patch_tenant, mock_engine): + def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -137,7 +137,7 @@ class TestDataSourceApi: assert status == 200 assert binding.disabled is True - def test_patch_binding_not_found(self, app, patch_tenant, mock_engine): + def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -152,7 +152,7 @@ class TestDataSourceApi: with pytest.raises(NotFound): method(api, "b1", "enable") - def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine): + def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -169,7 +169,7 @@ class TestDataSourceApi: with pytest.raises(ValueError): method(api, "b1", "enable") - def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): + def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -192,7 +192,7 @@ class TestDataSourceNotionListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_credential_not_found(self, app, patch_tenant): + def test_get_credential_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -206,7 +206,7 @@ class TestDataSourceNotionListApi: with pytest.raises(NotFound): method(api) - def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine): + def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -247,7 +247,7 @@ class TestDataSourceNotionListApi: assert status == 200 - def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine): + def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -300,7 +300,7 @@ class TestDataSourceNotionListApi: assert status == 200 - def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine): + def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -327,7 +327,7 @@ class TestDataSourceNotionApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_preview_success(self, app, patch_tenant): + def test_get_preview_success(self, app: Flask, patch_tenant): api = DataSourceNotionApi() method = unwrap(api.get) @@ -348,7 +348,7 @@ class TestDataSourceNotionApi: assert status == 200 - def test_post_indexing_estimate_success(self, app, patch_tenant): + def test_post_indexing_estimate_success(self, app: Flask, patch_tenant): api = DataSourceNotionApi() method = unwrap(api.post) @@ -385,7 +385,7 @@ class TestDataSourceNotionDatasetSyncApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceNotionDatasetSyncApi() method = unwrap(api.get) @@ -408,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi: assert status == 200 - def test_get_dataset_not_found(self, app, patch_tenant): + def test_get_dataset_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionDatasetSyncApi() method = unwrap(api.get) @@ -428,7 +428,7 @@ class TestDataSourceNotionDocumentSyncApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceNotionDocumentSyncApi() method = unwrap(api.get) @@ -451,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi: assert status == 200 - def test_get_document_not_found(self, app, patch_tenant): + def test_get_document_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionDocumentSyncApi() method = unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 0b53ca5585..917aa35fe6 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -57,7 +57,7 @@ class TestConversationListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, chat_app, user): + def test_get_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -82,7 +82,7 @@ class TestConversationListApi: assert result["has_more"] is False assert len(result["data"]) == 2 - def test_last_conversation_not_exists(self, app, chat_app, user): + def test_last_conversation_not_exists(self, app: Flask, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -98,7 +98,7 @@ class TestConversationListApi: with pytest.raises(NotFound): method(chat_app) - def test_wrong_app_mode(self, app, non_chat_app): + def test_wrong_app_mode(self, app: Flask, non_chat_app): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -112,7 +112,7 @@ class TestConversationApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_delete_success(self, app, chat_app, user): + def test_delete_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -130,7 +130,7 @@ class TestConversationApi: assert status == 204 assert body["result"] == "success" - def test_delete_not_found(self, app, chat_app, user): + def test_delete_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -146,7 +146,7 @@ class TestConversationApi: with pytest.raises(NotFound): method(chat_app, "cid") - def test_delete_wrong_app_mode(self, app, non_chat_app): + def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -160,7 +160,7 @@ class TestConversationRenameApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_rename_success(self, app, chat_app, user): + def test_rename_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -179,7 +179,7 @@ class TestConversationRenameApi: assert result["id"] == "cid" - def test_rename_not_found(self, app, chat_app, user): + def test_rename_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -201,7 +201,7 @@ class TestConversationPinApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_pin_success(self, app, chat_app, user): + def test_pin_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationPinApi() method = unwrap(api.patch) @@ -223,7 +223,7 @@ class TestConversationUnPinApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_unpin_success(self, app, chat_app, user): + def test_unpin_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationUnPinApi() method = unwrap(api.patch) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index 6efdaf2943..e41adccf3c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -49,7 +49,7 @@ class TestTriggerProviderApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_icon_success(self, app): + def test_icon_success(self, app: Flask): api = TriggerProviderIconApi() method = unwrap(api.get) @@ -63,7 +63,7 @@ class TestTriggerProviderApis: ): assert method(api, "github") == "icon" - def test_list_providers(self, app): + def test_list_providers(self, app: Flask): api = TriggerProviderListApi() method = unwrap(api.get) @@ -77,7 +77,7 @@ class TestTriggerProviderApis: ): assert method(api) == [] - def test_provider_info(self, app): + def test_provider_info(self, app: Flask): api = TriggerProviderInfoApi() method = unwrap(api.get) @@ -97,7 +97,7 @@ class TestTriggerSubscriptionListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = TriggerSubscriptionListApi() method = unwrap(api.get) @@ -111,7 +111,7 @@ class TestTriggerSubscriptionListApi: ): assert method(api, "github") == [] - def test_list_invalid_provider(self, app): + def test_list_invalid_provider(self, app: Flask): api = TriggerSubscriptionListApi() method = unwrap(api.get) @@ -132,7 +132,7 @@ class TestTriggerSubscriptionBuilderApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create_builder(self, app): + def test_create_builder(self, app: Flask): api = TriggerSubscriptionBuilderCreateApi() method = unwrap(api.post) @@ -147,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis: result = method(api, "github") assert "subscription_builder" in result - def test_get_builder(self, app): + def test_get_builder(self, app: Flask): api = TriggerSubscriptionBuilderGetApi() method = unwrap(api.get) @@ -160,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_verify_builder(self, app): + def test_verify_builder(self, app: Flask): api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) @@ -174,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"ok": True} - def test_verify_builder_error(self, app): + def test_verify_builder_error(self, app: Flask): api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) @@ -189,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis: with pytest.raises(ValueError): method(api, "github", "b1") - def test_update_builder(self, app): + def test_update_builder(self, app: Flask): api = TriggerSubscriptionBuilderUpdateApi() method = unwrap(api.post) @@ -203,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_logs(self, app): + def test_logs(self, app: Flask): api = TriggerSubscriptionBuilderLogsApi() method = unwrap(api.get) @@ -220,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert "logs" in method(api, "github", "b1") - def test_build(self, app): + def test_build(self, app: Flask): api = TriggerSubscriptionBuilderBuildApi() method = unwrap(api.post) @@ -240,7 +240,7 @@ class TestTriggerSubscriptionCrud: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_update_rename_only(self, app): + def test_update_rename_only(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -259,7 +259,7 @@ class TestTriggerSubscriptionCrud: ): assert method(api, "s1") == 200 - def test_update_not_found(self, app): + def test_update_not_found(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -274,7 +274,7 @@ class TestTriggerSubscriptionCrud: with pytest.raises(NotFoundError): method(api, "x") - def test_update_rebuild(self, app): + def test_update_rebuild(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -297,7 +297,7 @@ class TestTriggerSubscriptionCrud: ): assert method(api, "s1") == 200 - def test_delete_subscription(self, app): + def test_delete_subscription(self, app: Flask): api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -320,7 +320,7 @@ class TestTriggerSubscriptionCrud: assert result["result"] == "success" - def test_delete_subscription_value_error(self, app): + def test_delete_subscription_value_error(self, app: Flask): api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -346,7 +346,7 @@ class TestTriggerOAuthApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_authorize_success(self, app): + def test_oauth_authorize_success(self, app: Flask): api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) @@ -373,7 +373,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 200 - def test_oauth_authorize_no_client(self, app): + def test_oauth_authorize_no_client(self, app: Flask): api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) @@ -388,7 +388,7 @@ class TestTriggerOAuthApis: with pytest.raises(NotFoundError): method(api, "github") - def test_oauth_callback_forbidden(self, app): + def test_oauth_callback_forbidden(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -396,7 +396,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_success(self, app): + def test_oauth_callback_success(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -426,7 +426,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 302 - def test_oauth_callback_no_oauth_client(self, app): + def test_oauth_callback_no_oauth_client(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -450,7 +450,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_empty_credentials(self, app): + def test_oauth_callback_empty_credentials(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -484,7 +484,7 @@ class TestTriggerOAuthClientManageApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_client(self, app): + def test_get_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.get) @@ -511,7 +511,7 @@ class TestTriggerOAuthClientManageApi: result = method(api, "github") assert "configured" in result - def test_post_client(self, app): + def test_post_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.post) @@ -525,7 +525,7 @@ class TestTriggerOAuthClientManageApi: ): assert method(api, "github") == {"ok": True} - def test_delete_client(self, app): + def test_delete_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.delete) @@ -539,7 +539,7 @@ class TestTriggerOAuthClientManageApi: ): assert method(api, "github") == {"ok": True} - def test_oauth_client_post_value_error(self, app): + def test_oauth_client_post_value_error(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.post) @@ -560,7 +560,7 @@ class TestTriggerSubscriptionVerifyApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_verify_success(self, app): + def test_verify_success(self, app: Flask): api = TriggerSubscriptionVerifyApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 5791d2f6e2..b73d28e4c4 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -291,7 +291,7 @@ class TestDatasetListApiGet: mock_current_user, mock_provider_mgr, mock_marshal, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -326,7 +326,7 @@ class TestDatasetListApiPost: mock_dataset_svc, mock_current_user, mock_marshal, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -352,7 +352,7 @@ class TestDatasetListApiPost: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -390,7 +390,7 @@ class TestDatasetApiGet: mock_provider_mgr, mock_marshal, mock_perm_svc, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -440,7 +440,7 @@ class TestDatasetApiGet: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -468,7 +468,7 @@ class TestDatasetApiDelete: mock_dataset_svc, mock_current_user, mock_perm_svc, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -490,7 +490,7 @@ class TestDatasetApiDelete: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -511,7 +511,7 @@ class TestDatasetApiDelete: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -543,7 +543,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -574,7 +574,7 @@ class TestDocumentStatusApiPatch: def test_batch_update_status_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -603,7 +603,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -636,7 +636,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -669,7 +669,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -709,7 +709,7 @@ class TestDatasetTagsApiGet: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -731,7 +731,7 @@ class TestDatasetTagsApiGet: def test_list_tags_from_db( self, mock_current_user, - app, + app: Flask, db_session_with_containers: Session, ): """Integration test: creates real Tag rows and retrieves them @@ -774,7 +774,7 @@ class TestDatasetTagsApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -797,7 +797,7 @@ class TestDatasetTagsApiPost: mock_tag_svc.save_tags.assert_called_once() @patch("controllers.service_api.dataset.dataset.current_user") - def test_create_tag_forbidden(self, mock_current_user, app): + def test_create_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -826,7 +826,7 @@ class TestDatasetTagsApiPatch: mock_current_user, mock_service_api_ns, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -852,7 +852,7 @@ class TestDatasetTagsApiPatch: mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1") @patch("controllers.service_api.dataset.dataset.current_user") - def test_update_tag_forbidden(self, mock_current_user, app): + def test_update_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -880,7 +880,7 @@ class TestDatasetTagsApiDelete: mock_current_user, mock_service_api_ns, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -905,7 +905,7 @@ class TestDatasetTagsApiDelete: mock_tag_svc.delete_tag.assert_called_once_with("tag-1") @patch("libs.login.current_user") - def test_delete_tag_forbidden(self, mock_current_user, app): + def test_delete_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi user_obj = Mock(spec=Account) @@ -933,7 +933,7 @@ class TestDatasetTagsBindingStatusApi: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi @@ -963,7 +963,7 @@ class TestDatasetTagBindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagBindingApi @@ -988,7 +988,7 @@ class TestDatasetTagBindingApiPost: ) @patch("controllers.service_api.dataset.dataset.current_user") - def test_bind_tags_forbidden(self, mock_current_user, app): + def test_bind_tags_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1014,7 +1014,7 @@ class TestDatasetTagUnbindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi @@ -1044,7 +1044,7 @@ class TestDatasetTagUnbindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi @@ -1069,7 +1069,7 @@ class TestDatasetTagUnbindingApiPost: ) @patch("controllers.service_api.dataset.dataset.current_user") - def test_unbind_tag_forbidden(self, mock_current_user, app): + def test_unbind_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi mock_current_user.__class__ = Account diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py index de9e691434..0a4e495f36 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py @@ -240,7 +240,7 @@ class TestDecodeJwtToken: mock_access_mode: MagicMock, mock_validate_token: MagicMock, mock_validate_user: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers) @@ -300,7 +300,7 @@ class TestDecodeJwtToken: mock_extract: MagicMock, mock_passport_cls: MagicMock, mock_features: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False) @@ -325,7 +325,7 @@ class TestDecodeJwtToken: mock_extract: MagicMock, mock_passport_cls: MagicMock, mock_features: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, _ = self._create_app_site_enduser(db_session_with_containers) @@ -351,7 +351,7 @@ class TestDecodeJwtToken: mock_extract: MagicMock, mock_passport_cls: MagicMock, mock_features: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index 54ee133bfe..d1af0a56ef 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -21,6 +21,8 @@ from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus +TenantAndAccount = tuple[Tenant, Account] + @dataclass class TestTask: @@ -74,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration: return tenant, account @pytest.fixture - def test_queue(self, test_tenant_and_account): + def test_queue(self, test_tenant_and_account: TenantAndAccount): """Create a generic test queue for testing.""" tenant, _ = test_tenant_and_account return TenantIsolatedTaskQueue(tenant.id, "test_queue") @pytest.fixture - def secondary_queue(self, test_tenant_and_account): + def secondary_queue(self, test_tenant_and_account: TenantAndAccount): """Create a secondary test queue for testing isolation.""" tenant, _ = test_tenant_and_account return TenantIsolatedTaskQueue(tenant.id, "secondary_queue") - def test_queue_initialization(self, test_tenant_and_account): + def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount): """Test queue initialization with correct key generation.""" tenant, _ = test_tenant_and_account queue = TenantIsolatedTaskQueue(tenant.id, "test-key") @@ -95,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" assert queue._task_key == f"tenant_test-key_task:{tenant.id}" - def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker): + def test_tenant_isolation( + self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker + ): """Test that different tenants have isolated queues.""" tenant1, _ = test_tenant_and_account @@ -115,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}" assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}" - def test_key_isolation(self, test_tenant_and_account): + def test_key_isolation(self, test_tenant_and_account: TenantAndAccount): """Test that different keys have isolated queues.""" tenant, _ = test_tenant_and_account queue1 = TenantIsolatedTaskQueue(tenant.id, "key1") @@ -293,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert isinstance(task, dict) assert task["index"] == i # FIFO order - def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker): + def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """Test concurrent operations on different queues.""" tenant, _ = test_tenant_and_account @@ -436,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return tenant, account - def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker): + def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """ Test compatibility with legacy queues containing only string data. @@ -466,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility: expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] assert pulled_tasks == expected_order - def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker): + def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """ Test complete migration scenario from legacy to new system. @@ -547,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility: assert task["tenant_id"] == tenant.id assert task["processing_type"] == "new_system" - def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker): + def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """ Test error recovery when legacy queue contains malformed data. diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 3229693fd4..e2fe6c8476 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -7,6 +7,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom +from models import App from models.model import EndUser from models.workflow import Workflow from services.app_generate_service import AppGenerateService @@ -184,7 +185,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers: Session, app): + def _create_test_workflow(self, db_session_with_containers: Session, app: App): """ Helper method to create a test workflow for testing. diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index cd63d3ad6c..1a1efe0337 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration: return app - def _create_conversation(self, db_session_with_containers: Session, app): + def _create_conversation(self, db_session_with_containers: Session, app: App): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 70aa813142..7b9e9924cd 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models import App, CreatorUserRole from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage @@ -88,7 +89,7 @@ class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers: Session, app): + def _create_test_end_user(self, db_session_with_containers: Session, app: App): """ Helper method to create a test end user for testing. @@ -116,7 +117,7 @@ class TestSavedMessageService: return end_user - def _create_test_message(self, db_session_with_containers: Session, app, user): + def _create_test_message(self, db_session_with_containers: Session, app: App, user): """ Helper method to create a test message for testing. @@ -199,13 +200,13 @@ class TestSavedMessageService: saved_message1 = SavedMessage( app_id=app.id, message_id=message1.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) saved_message2 = SavedMessage( app_id=app.id, message_id=message2.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -272,13 +273,13 @@ class TestSavedMessageService: saved_message1 = SavedMessage( app_id=app.id, message_id=message1.id, - created_by_role="end_user", + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) saved_message2 = SavedMessage( app_id=app.id, message_id=message2.id, - created_by_role="end_user", + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -449,7 +450,7 @@ class TestSavedMessageService: saved_message = SavedMessage( app_id=app.id, message_id=message.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -540,7 +541,9 @@ class TestSavedMessageService: message = self._create_test_message(db_session_with_containers, app, account) # Pre-create a saved message - saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + saved = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id + ) db_session_with_containers.add(saved) db_session_with_containers.commit() @@ -571,7 +574,9 @@ class TestSavedMessageService: end_user = self._create_test_end_user(db_session_with_containers, app) message = self._create_test_message(db_session_with_containers, app, end_user) - saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + saved = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id + ) db_session_with_containers.add(saved) db_session_with_containers.commit() @@ -596,10 +601,10 @@ class TestSavedMessageService: # Both users save the same message saved_account = SavedMessage( - app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account1.id ) saved_end_user = SavedMessage( - app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id ) db_session_with_containers.add_all([saved_account, saved_end_user]) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index f2307fbd7d..797731d04b 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account +from models import Account, App from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation @@ -93,7 +93,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers: Session, app): + def _create_test_end_user(self, db_session_with_containers: Session, app: App): """ Helper method to create a test end user for testing. diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index 78413a0798..0fb0ebc330 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -185,7 +185,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, ): @@ -263,7 +263,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, ): @@ -312,7 +312,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, language, @@ -358,7 +358,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, ): """ @@ -398,7 +398,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, ): """ @@ -438,7 +438,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, ): diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 7b2c7569fe..102af9b250 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -195,7 +195,7 @@ class TestEmailCodeLoginSendEmailApi: mock_get_user, mock_is_ip_limit, mock_db, - app, + app: Flask, mock_account, language_input, expected_language, @@ -267,7 +267,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -315,7 +315,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -431,7 +431,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -474,7 +474,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -515,7 +515,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 5284f29eed..ace2ce5706 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -412,7 +412,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -448,7 +448,7 @@ class TestLoginApi: mock_revoke_token, mock_get_token_data, mock_db, - app, + app: Flask, ): mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index 15c95f6b94..22974ca416 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -74,7 +74,7 @@ class TestRefreshTokenApi: assert response.json["result"] == "success" @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) - def test_refresh_fails_without_token(self, mock_extract_token, app): + def test_refresh_fails_without_token(self, mock_extract_token, app: Flask): """ Test token refresh failure when no refresh token provided. @@ -98,7 +98,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh failure with invalid refresh token. @@ -123,7 +123,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh failure with expired refresh token. @@ -148,7 +148,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh with empty string token. diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 5136922e88..9c5b5ec256 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -29,7 +30,7 @@ def unwrap(func): class TestDatasourcePluginOAuthAuthorizationUrl: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -61,7 +62,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: assert response.status_code == 200 - def test_get_no_oauth_config(self, app): + def test_get_no_oauth_config(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -80,7 +81,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: with pytest.raises(ValueError): method(api, "notion") - def test_get_without_credential_id_sets_cookie(self, app): + def test_get_without_credential_id_sets_cookie(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -115,7 +116,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: class TestDatasourceOAuthCallback: - def test_callback_success_new_credential(self, app): + def test_callback_success_new_credential(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -157,7 +158,7 @@ class TestDatasourceOAuthCallback: assert response.status_code == 302 - def test_callback_missing_context(self, app): + def test_callback_missing_context(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -165,7 +166,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(Forbidden): method(api, "notion") - def test_callback_invalid_context(self, app): + def test_callback_invalid_context(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -180,7 +181,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(Forbidden): method(api, "notion") - def test_callback_oauth_config_not_found(self, app): + def test_callback_oauth_config_not_found(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -202,7 +203,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(NotFound): method(api, "notion") - def test_callback_reauthorize_existing_credential(self, app): + def test_callback_reauthorize_existing_credential(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -245,7 +246,7 @@ class TestDatasourceOAuthCallback: assert response.status_code == 302 assert "/oauth-callback" in response.location - def test_callback_context_id_from_cookie(self, app): + def test_callback_context_id_from_cookie(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -289,7 +290,7 @@ class TestDatasourceOAuthCallback: class TestDatasourceAuth: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -312,7 +313,7 @@ class TestDatasourceAuth: assert status == 200 - def test_post_invalid_credentials(self, app): + def test_post_invalid_credentials(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -334,7 +335,7 @@ class TestDatasourceAuth: with pytest.raises(ValueError): method(api, "notion") - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasourceAuth() method = unwrap(api.get) @@ -355,7 +356,7 @@ class TestDatasourceAuth: assert status == 200 assert response["result"] - def test_post_missing_credentials(self, app): + def test_post_missing_credentials(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -372,7 +373,7 @@ class TestDatasourceAuth: with pytest.raises(ValueError): method(api, "notion") - def test_get_empty_list(self, app): + def test_get_empty_list(self, app: Flask): api = DatasourceAuth() method = unwrap(api.get) @@ -395,7 +396,7 @@ class TestDatasourceAuth: class TestDatasourceAuthDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasourceAuthDeleteApi() method = unwrap(api.post) @@ -418,7 +419,7 @@ class TestDatasourceAuthDeleteApi: assert status == 200 - def test_delete_missing_credential_id(self, app): + def test_delete_missing_credential_id(self, app: Flask): api = DatasourceAuthDeleteApi() method = unwrap(api.post) @@ -437,7 +438,7 @@ class TestDatasourceAuthDeleteApi: class TestDatasourceAuthUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestDatasourceAuthUpdateApi: assert status == 201 - def test_update_with_credentials_none(self, app): + def test_update_with_credentials_none(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -484,7 +485,7 @@ class TestDatasourceAuthUpdateApi: update_mock.assert_called_once() assert status == 201 - def test_update_name_only(self, app): + def test_update_name_only(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -507,7 +508,7 @@ class TestDatasourceAuthUpdateApi: assert status == 201 - def test_update_with_empty_credentials_dict(self, app): + def test_update_with_empty_credentials_dict(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestDatasourceAuthUpdateApi: class TestDatasourceAuthListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = DatasourceAuthListApi() method = unwrap(api.get) @@ -553,7 +554,7 @@ class TestDatasourceAuthListApi: assert status == 200 - def test_auth_list_empty(self, app): + def test_auth_list_empty(self, app: Flask): api = DatasourceAuthListApi() method = unwrap(api.get) @@ -574,7 +575,7 @@ class TestDatasourceAuthListApi: assert status == 200 assert response["result"] == [] - def test_hardcode_list_empty(self, app): + def test_hardcode_list_empty(self, app: Flask): api = DatasourceHardCodeAuthListApi() method = unwrap(api.get) @@ -597,7 +598,7 @@ class TestDatasourceAuthListApi: class TestDatasourceHardCodeAuthListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = DatasourceHardCodeAuthListApi() method = unwrap(api.get) @@ -619,7 +620,7 @@ class TestDatasourceHardCodeAuthListApi: class TestDatasourceAuthOauthCustomClient: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -642,7 +643,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.delete) @@ -662,7 +663,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_post_empty_payload(self, app): + def test_post_empty_payload(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -685,7 +686,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_post_disabled_flag(self, app): + def test_post_disabled_flag(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -714,7 +715,7 @@ class TestDatasourceAuthOauthCustomClient: class TestDatasourceAuthDefaultApi: - def test_set_default_success(self, app): + def test_set_default_success(self, app: Flask): api = DatasourceAuthDefaultApi() method = unwrap(api.post) @@ -737,7 +738,7 @@ class TestDatasourceAuthDefaultApi: assert status == 200 - def test_default_missing_id(self, app): + def test_default_missing_id(self, app: Flask): api = DatasourceAuthDefaultApi() method = unwrap(api.post) @@ -756,7 +757,7 @@ class TestDatasourceAuthDefaultApi: class TestDatasourceUpdateProviderNameApi: - def test_update_name_success(self, app): + def test_update_name_success(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) @@ -779,7 +780,7 @@ class TestDatasourceUpdateProviderNameApi: assert status == 200 - def test_update_name_too_long(self, app): + def test_update_name_too_long(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) @@ -799,7 +800,7 @@ class TestDatasourceUpdateProviderNameApi: with pytest.raises(ValueError): method(api, "notion") - def test_update_name_missing_credential_id(self, app): + def test_update_name_missing_credential_id(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py index 7a8ccde55a..d4c6a775ec 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console import console_ns @@ -25,7 +26,7 @@ class TestDataSourceContentPreviewApi: "credential_id": "cred-1", } - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -66,7 +67,7 @@ class TestDataSourceContentPreviewApi: assert status == 200 assert response == preview_result - def test_post_forbidden_non_account_user(self, app): + def test_post_forbidden_non_account_user(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -85,7 +86,7 @@ class TestDataSourceContentPreviewApi: with pytest.raises(Forbidden): method(api, pipeline, "node-1") - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -108,7 +109,7 @@ class TestDataSourceContentPreviewApi: with pytest.raises(ValueError): method(api, pipeline, "node-1") - def test_post_without_credential_id(self, app): + def test_post_without_credential_id(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 9465936f28..e28d68ee5a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -2,6 +2,7 @@ import datetime from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden, NotFound import services @@ -58,7 +59,7 @@ class TestDatasetList: user.is_dataset_editor = True return user - def test_get_success_basic(self, app): + def test_get_success_basic(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -93,7 +94,7 @@ class TestDatasetList: assert resp["total"] == 1 assert resp["data"][0]["embedding_available"] is True - def test_get_with_ids_filter(self, app): + def test_get_with_ids_filter(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -128,7 +129,7 @@ class TestDatasetList: assert status == 200 assert resp["total"] == 2 - def test_get_with_tag_ids(self, app): + def test_get_with_tag_ids(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -161,7 +162,7 @@ class TestDatasetList: assert status == 200 - def test_embedding_available_false(self, app): + def test_embedding_available_false(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -203,7 +204,7 @@ class TestDatasetList: assert resp["data"][0]["embedding_available"] is False - def test_partial_members_permission(self, app): + def test_partial_members_permission(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -242,7 +243,7 @@ class TestDatasetList: class TestDatasetListApiPost: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -290,7 +291,7 @@ class TestDatasetListApiPost: assert status == 201 - def test_post_forbidden(self, app): + def test_post_forbidden(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -310,7 +311,7 @@ class TestDatasetListApiPost: with pytest.raises(Forbidden): method(api) - def test_post_duplicate_name(self, app): + def test_post_duplicate_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -335,7 +336,7 @@ class TestDatasetListApiPost: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload_missing_name(self, app): + def test_post_invalid_payload_missing_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -343,7 +344,7 @@ class TestDatasetListApiPost: with pytest.raises(ValueError): method(api) - def test_post_invalid_indexing_technique(self, app): + def test_post_invalid_indexing_technique(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -356,7 +357,7 @@ class TestDatasetListApiPost: with pytest.raises(ValueError, match="Invalid indexing technique"): method(api) - def test_post_invalid_provider(self, app): + def test_post_invalid_provider(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -371,7 +372,7 @@ class TestDatasetListApiPost: class TestDatasetApiGet: - def test_get_success_basic(self, app): + def test_get_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -427,7 +428,7 @@ class TestDatasetApiGet: assert status == 200 assert data["embedding_available"] is True - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -448,7 +449,7 @@ class TestDatasetApiGet: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -475,7 +476,7 @@ class TestDatasetApiGet: with pytest.raises(Forbidden, match="no access"): method(api, dataset_id) - def test_get_high_quality_embedding_unavailable(self, app): + def test_get_high_quality_embedding_unavailable(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -530,7 +531,7 @@ class TestDatasetApiGet: assert data["embedding_available"] is False - def test_get_partial_members_permission(self, app): + def test_get_partial_members_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -590,7 +591,7 @@ class TestDatasetApiGet: class TestDatasetApiPatch: - def test_patch_success_basic(self, app): + def test_patch_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -659,7 +660,7 @@ class TestDatasetApiPatch: assert status == 200 assert result["partial_member_list"] == [] - def test_patch_dataset_not_found(self, app): + def test_patch_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -674,7 +675,7 @@ class TestDatasetApiPatch: with pytest.raises(NotFound, match="Dataset not found"): method(api, "missing") - def test_patch_permission_denied(self, app): + def test_patch_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -704,7 +705,7 @@ class TestDatasetApiPatch: with pytest.raises(Forbidden): method(api, dataset_id) - def test_patch_partial_members_update(self, app): + def test_patch_partial_members_update(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -773,7 +774,7 @@ class TestDatasetApiPatch: assert result["partial_member_list"] == payload["partial_member_list"] - def test_patch_clear_partial_members(self, app): + def test_patch_clear_partial_members(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -843,7 +844,7 @@ class TestDatasetApiPatch: class TestDatasetApiDelete: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -874,7 +875,7 @@ class TestDatasetApiDelete: assert status == 204 assert result == {"result": "success"} - def test_delete_forbidden_no_permission(self, app): + def test_delete_forbidden_no_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -893,7 +894,7 @@ class TestDatasetApiDelete: with pytest.raises(Forbidden): method(api, dataset_id) - def test_delete_dataset_not_found(self, app): + def test_delete_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -917,7 +918,7 @@ class TestDatasetApiDelete: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_delete_dataset_in_use(self, app): + def test_delete_dataset_in_use(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -943,7 +944,7 @@ class TestDatasetApiDelete: class TestDatasetUseCheckApi: - def test_get_use_check_true(self, app): + def test_get_use_check_true(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) @@ -962,7 +963,7 @@ class TestDatasetUseCheckApi: assert status == 200 assert result == {"is_using": True} - def test_get_use_check_false(self, app): + def test_get_use_check_false(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) @@ -983,7 +984,7 @@ class TestDatasetUseCheckApi: class TestDatasetQueryApi: - def test_get_queries_success(self, app): + def test_get_queries_success(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1027,7 +1028,7 @@ class TestDatasetQueryApi: assert response["has_more"] is False assert len(response["data"]) == 2 - def test_get_queries_dataset_not_found(self, app): + def test_get_queries_dataset_not_found(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1049,7 +1050,7 @@ class TestDatasetQueryApi: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_get_queries_permission_denied(self, app): + def test_get_queries_permission_denied(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1078,7 +1079,7 @@ class TestDatasetQueryApi: with pytest.raises(Forbidden): method(api, dataset_id) - def test_get_queries_pagination_has_more(self, app): + def test_get_queries_pagination_has_more(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1152,7 +1153,7 @@ class TestDatasetIndexingEstimateApi: "dataset_id": None, } - def test_post_success_upload_file(self, app): + def test_post_success_upload_file(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) @@ -1193,7 +1194,7 @@ class TestDatasetIndexingEstimateApi: assert status == 200 assert response == {"tokens": 100} - def test_post_file_not_found(self, app): + def test_post_file_not_found(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) @@ -1223,7 +1224,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(NotFound): method(api) - def test_post_llm_bad_request_error(self, app): + def test_post_llm_bad_request_error(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1258,7 +1259,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(ProviderNotInitializeError): method(api) - def test_post_provider_token_not_init(self, app): + def test_post_provider_token_not_init(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1293,7 +1294,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(ProviderNotInitializeError): method(api) - def test_post_generic_exception(self, app): + def test_post_generic_exception(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1330,7 +1331,7 @@ class TestDatasetIndexingEstimateApi: class TestDatasetRelatedAppListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1368,7 +1369,7 @@ class TestDatasetRelatedAppListApi: assert response["total"] == 2 assert response["data"] == [app1, app2] - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1386,7 +1387,7 @@ class TestDatasetRelatedAppListApi: with pytest.raises(NotFound): method(api, "dataset-1") - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1410,7 +1411,7 @@ class TestDatasetRelatedAppListApi: with pytest.raises(Forbidden): method(api, "dataset-1") - def test_get_filters_none_apps(self, app): + def test_get_filters_none_apps(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1449,7 +1450,7 @@ class TestDatasetRelatedAppListApi: class TestDatasetIndexingStatusApi: - def test_get_success_with_documents(self, app): + def test_get_success_with_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1490,7 +1491,7 @@ class TestDatasetIndexingStatusApi: assert item["completed_segments"] == 3 assert item["total_segments"] == 3 - def test_get_success_no_documents(self, app): + def test_get_success_no_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1510,7 +1511,7 @@ class TestDatasetIndexingStatusApi: assert status == 200 assert response == {"data": []} - def test_segment_counts_different_values(self, app): + def test_segment_counts_different_values(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1550,7 +1551,7 @@ class TestDatasetIndexingStatusApi: class TestDatasetApiKeyApi: - def test_get_api_keys_success(self, app): + def test_get_api_keys_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.get) @@ -1587,7 +1588,7 @@ class TestDatasetApiKeyApi: assert response["data"][1]["id"] == "key-2" assert response["data"][1]["token"] == "ds-def" - def test_post_create_api_key_success(self, app): + def test_post_create_api_key_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) @@ -1632,7 +1633,7 @@ class TestDatasetApiKeyApi: assert response["type"] == "dataset" assert response["created_at"] is not None - def test_post_exceed_max_keys(self, app): + def test_post_exceed_max_keys(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) @@ -1658,7 +1659,7 @@ class TestDatasetApiKeyApi: class TestDatasetApiDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) @@ -1688,7 +1689,7 @@ class TestDatasetApiDeleteApi: assert status == 204 assert response["result"] == "success" - def test_delete_key_not_found(self, app): + def test_delete_key_not_found(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) @@ -1708,7 +1709,7 @@ class TestDatasetApiDeleteApi: class TestDatasetEnableApiApi: - def test_enable_api(self, app): + def test_enable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) @@ -1724,7 +1725,7 @@ class TestDatasetEnableApiApi: assert status == 200 assert response["result"] == "success" - def test_disable_api(self, app): + def test_disable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) @@ -1742,7 +1743,7 @@ class TestDatasetEnableApiApi: class TestDatasetApiBaseUrlApi: - def test_get_api_base_url_from_config(self, app): + def test_get_api_base_url_from_config(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1757,7 +1758,7 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "https://example.com/v1" - def test_get_api_base_url_from_request(self, app): + def test_get_api_base_url_from_request(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1772,7 +1773,7 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" - def test_get_api_base_url_no_double_v1(self, app): + def test_get_api_base_url_no_double_v1(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1789,7 +1790,7 @@ class TestDatasetApiBaseUrlApi: class TestDatasetRetrievalSettingApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRetrievalSettingApi() method = unwrap(api.get) @@ -1810,7 +1811,7 @@ class TestDatasetRetrievalSettingApi: class TestDatasetRetrievalSettingMockApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRetrievalSettingMockApi() method = unwrap(api.get) @@ -1827,7 +1828,7 @@ class TestDatasetRetrievalSettingMockApi: class TestDatasetErrorDocs: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) @@ -1850,7 +1851,7 @@ class TestDatasetErrorDocs: assert status == 200 assert response["total"] == 1 - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) @@ -1866,7 +1867,7 @@ class TestDatasetErrorDocs: class TestDatasetPermissionUserListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) @@ -1897,7 +1898,7 @@ class TestDatasetPermissionUserListApi: assert status == 200 assert response["data"] == users - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) @@ -1923,7 +1924,7 @@ class TestDatasetPermissionUserListApi: class TestDatasetAutoDisableLogApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) @@ -1946,7 +1947,7 @@ class TestDatasetAutoDisableLogApi: assert status == 200 assert response == logs - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index d9b02ac453..ff9e1736d2 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound import services @@ -239,7 +240,7 @@ class TestDatasetDocumentListApi: assert "documents" in response - def test_post_forbidden(self, app): + def test_post_forbidden(self, app: Flask): api = DatasetDocumentListApi() method = unwrap(api.post) @@ -395,7 +396,7 @@ class TestDocumentDownloadApi: class TestDocumentProcessingApi: - def test_processing_forbidden_when_not_editor(self, app): + def test_processing_forbidden_when_not_editor(self, app: Flask): api = DocumentProcessingApi() method = unwrap(api.patch) @@ -1185,7 +1186,7 @@ class TestDocumentPermissionCases: "preview": [], } - def test_document_tenant_mismatch(self, app): + def test_document_tenant_mismatch(self, app: Flask): api = DocumentApi() method = unwrap(api.get) @@ -1253,7 +1254,7 @@ class TestDocumentPermissionCases: assert status == 200 assert response["mode"] == "custom" - def test_process_rule_permission_denied(self, app): + def test_process_rule_permission_denied(self, app: Flask): api = GetProcessRuleApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 693b06e95b..412edb9dfe 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound import services @@ -82,7 +83,7 @@ def test_get_segment_with_summary(monkeypatch): class TestDatasetDocumentSegmentListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -132,7 +133,7 @@ class TestDatasetDocumentSegmentListApi: assert status == 200 - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -150,7 +151,7 @@ class TestDatasetDocumentSegmentListApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -176,7 +177,7 @@ class TestDatasetDocumentSegmentListApi: class TestDatasetDocumentSegmentApi: - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -221,7 +222,7 @@ class TestDatasetDocumentSegmentApi: assert status == 200 assert response["result"] == "success" - def test_patch_document_indexing_in_progress(self, app): + def test_patch_document_indexing_in_progress(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -264,7 +265,7 @@ class TestDatasetDocumentSegmentApi: with pytest.raises(InvalidActionError): method(api, "ds-1", "doc-1", "disable") - def test_patch_llm_bad_request(self, app): + def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -308,7 +309,7 @@ class TestDatasetDocumentSegmentApi: with pytest.raises(ProviderNotInitializeError): method(api, "ds-1", "doc-1", "enable") - def test_patch_provider_token_not_init(self, app): + def test_patch_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -354,7 +355,7 @@ class TestDatasetDocumentSegmentApi: class TestDatasetDocumentSegmentAddApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -413,7 +414,7 @@ class TestDatasetDocumentSegmentAddApi: assert status == 200 assert response["data"]["id"] == "seg-1" - def test_post_llm_bad_request(self, app): + def test_post_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -452,7 +453,7 @@ class TestDatasetDocumentSegmentAddApi: with pytest.raises(ProviderNotInitializeError): method(api, "ds-1", "doc-1") - def test_post_provider_token_not_init(self, app): + def test_post_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -493,7 +494,7 @@ class TestDatasetDocumentSegmentAddApi: class TestDatasetDocumentSegmentUpdateApi: - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() method = unwrap(api.patch) @@ -551,7 +552,7 @@ class TestDatasetDocumentSegmentUpdateApi: assert status == 200 assert "data" in response - def test_patch_llm_bad_request(self, app): + def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() method = unwrap(api.patch) @@ -596,7 +597,7 @@ class TestDatasetDocumentSegmentUpdateApi: class TestDatasetDocumentSegmentBatchImportApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -638,7 +639,7 @@ class TestDatasetDocumentSegmentBatchImportApi: assert status == 200 assert response["job_status"] == "waiting" - def test_post_dataset_not_found(self, app): + def test_post_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -659,7 +660,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_document_not_found(self, app): + def test_post_document_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -684,7 +685,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_upload_file_not_found(self, app): + def test_post_upload_file_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -713,7 +714,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_invalid_file_type(self, app): + def test_post_invalid_file_type(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -745,7 +746,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(ValueError): method(api, "ds-1", "doc-1") - def test_post_async_task_failure(self, app): + def test_post_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -783,7 +784,7 @@ class TestDatasetDocumentSegmentBatchImportApi: assert status == 500 assert "error" in response - def test_get_job_not_found_in_redis(self, app): + def test_get_job_not_found_in_redis(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.get) @@ -799,7 +800,7 @@ class TestDatasetDocumentSegmentBatchImportApi: class TestChildChunkAddApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = ChildChunkAddApi() method = unwrap(api.post) @@ -852,7 +853,7 @@ class TestChildChunkAddApi: assert status == 200 assert response["data"]["id"] == "cc-1" - def test_post_child_chunk_indexing_error(self, app): + def test_post_child_chunk_indexing_error(self, app: Flask): api = ChildChunkAddApi() method = unwrap(api.post) @@ -897,7 +898,7 @@ class TestChildChunkAddApi: class TestChildChunkUpdateApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ChildChunkUpdateApi() method = unwrap(api.delete) @@ -941,7 +942,7 @@ class TestChildChunkUpdateApi: assert status == 204 assert response["result"] == "success" - def test_delete_child_chunk_index_error(self, app): + def test_delete_child_chunk_index_error(self, app: Flask): api = ChildChunkUpdateApi() method = unwrap(api.delete) @@ -984,7 +985,7 @@ class TestChildChunkUpdateApi: class TestSegmentListAdvancedCases: - def test_segment_list_with_keyword_filter(self, app): + def test_segment_list_with_keyword_filter(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1035,7 +1036,7 @@ class TestSegmentListAdvancedCases: assert status == 200 assert response["total"] == 1 - def test_segment_list_permission_denied(self, app): + def test_segment_list_permission_denied(self, app: Flask): """Test segment list with permission denied""" api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1058,7 +1059,7 @@ class TestSegmentListAdvancedCases: with pytest.raises(Forbidden): method(api, "ds-1", "doc-1") - def test_segment_list_dataset_not_found(self, app): + def test_segment_list_dataset_not_found(self, app: Flask): """Test segment list with dataset not found""" api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1079,7 +1080,7 @@ class TestSegmentListAdvancedCases: class TestSegmentOperationCases: - def test_segment_add_with_provider_token_error(self, app): + def test_segment_add_with_provider_token_error(self, app: Flask): """Test segment add with provider token not initialized""" api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -1117,7 +1118,7 @@ class TestSegmentOperationCases: with pytest.raises(ProviderTokenNotInitError): method(api, "ds-1", "doc-1") - def test_batch_import_with_document_not_found(self, app): + def test_batch_import_with_document_not_found(self, app: Flask): """Test batch import with document not found""" api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1146,7 +1147,7 @@ class TestSegmentOperationCases: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_batch_import_with_invalid_file(self, app): + def test_batch_import_with_invalid_file(self, app: Flask): """Test batch import with invalid file type""" api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1181,7 +1182,7 @@ class TestSegmentOperationCases: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_batch_import_with_async_task_failure(self, app): + def test_batch_import_with_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1226,7 +1227,7 @@ class TestSegmentOperationCases: assert status == 500 assert "error" in response - def test_batch_import_get_job_not_found(self, app): + def test_batch_import_get_job_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 514bbbe040..7254bf7670 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -57,7 +57,7 @@ def mock_auth(monkeypatch, current_user): class TestExternalApiTemplateListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ExternalApiTemplateListApi() method = unwrap(api.get) @@ -78,7 +78,7 @@ class TestExternalApiTemplateListApi: assert resp["total"] == 1 assert resp["data"][0]["id"] == "1" - def test_post_forbidden(self, app, current_user): + def test_post_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False api = ExternalApiTemplateListApi() method = unwrap(api.post) @@ -93,7 +93,7 @@ class TestExternalApiTemplateListApi: with pytest.raises(Forbidden): method(api) - def test_post_duplicate_name(self, app): + def test_post_duplicate_name(self, app: Flask): api = ExternalApiTemplateListApi() method = unwrap(api.post) @@ -114,7 +114,7 @@ class TestExternalApiTemplateListApi: class TestExternalApiTemplateApi: - def test_get_not_found(self, app): + def test_get_not_found(self, app: Flask): api = ExternalApiTemplateApi() method = unwrap(api.get) @@ -129,7 +129,7 @@ class TestExternalApiTemplateApi: with pytest.raises(NotFound): method(api, "api-id") - def test_delete_forbidden(self, app, current_user): + def test_delete_forbidden(self, app: Flask, current_user): current_user.has_edit_permission = False current_user.is_dataset_operator = False @@ -142,7 +142,7 @@ class TestExternalApiTemplateApi: class TestExternalApiUseCheckApi: - def test_get_scopes_usage_check_to_current_tenant(self, app): + def test_get_scopes_usage_check_to_current_tenant(self, app: Flask): api = ExternalApiUseCheckApi() method = unwrap(api.get) @@ -162,7 +162,7 @@ class TestExternalApiUseCheckApi: class TestExternalDatasetCreateApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = ExternalDatasetCreateApi() method = unwrap(api.post) @@ -206,7 +206,7 @@ class TestExternalDatasetCreateApi: assert status == 201 - def test_create_forbidden(self, app, current_user): + def test_create_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False api = ExternalDatasetCreateApi() method = unwrap(api.post) @@ -226,7 +226,7 @@ class TestExternalDatasetCreateApi: class TestExternalKnowledgeHitTestingApi: - def test_hit_testing_dataset_not_found(self, app): + def test_hit_testing_dataset_not_found(self, app: Flask): api = ExternalKnowledgeHitTestingApi() method = unwrap(api.post) @@ -241,7 +241,7 @@ class TestExternalKnowledgeHitTestingApi: with pytest.raises(NotFound): method(api, "dataset-id") - def test_hit_testing_success(self, app): + def test_hit_testing_success(self, app: Flask): api = ExternalKnowledgeHitTestingApi() method = unwrap(api.post) @@ -266,7 +266,7 @@ class TestExternalKnowledgeHitTestingApi: class TestBedrockRetrievalApi: - def test_bedrock_retrieval(self, app): + def test_bedrock_retrieval(self, app: Flask): api = BedrockRetrievalApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py index de834c2d4d..0105aacd65 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -269,7 +269,7 @@ class TestDatasetMetadataApi: class TestDatasetMetadataBuiltInFieldApi: - def test_get_built_in_fields(self, app): + def test_get_built_in_fields(self, app: Flask): api = DatasetMetadataBuiltInFieldApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index c8f674f515..d1cb6b6a03 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from flask import Flask + import controllers.console.explore.banner as banner_module from models.enums import BannerStatus @@ -12,7 +14,7 @@ def unwrap(func): class TestBannerApi: - def test_get_banners_with_requested_language(self, app): + def test_get_banners_with_requested_language(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) @@ -41,7 +43,7 @@ class TestBannerApi: } ] - def test_get_banners_fallback_to_en_us(self, app): + def test_get_banners_fallback_to_en_us(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) @@ -76,7 +78,7 @@ class TestBannerApi: } ] - def test_get_banners_default_language_en_us(self, app): + def test_get_banners_default_language_en_us(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 145cc9cdd7..3d41489435 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -54,7 +55,7 @@ def make_message(): class TestMessageListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -96,7 +97,7 @@ class TestMessageListApi: with pytest.raises(NotChatAppError): method(installed_app) - def test_conversation_not_exists(self, app): + def test_conversation_not_exists(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -118,7 +119,7 @@ class TestMessageListApi: with pytest.raises(NotFound): method(installed_app) - def test_first_message_not_exists(self, app): + def test_first_message_not_exists(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -142,7 +143,7 @@ class TestMessageListApi: class TestMessageFeedbackApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = module.MessageFeedbackApi() method = unwrap(api.post) @@ -161,7 +162,7 @@ class TestMessageFeedbackApi: assert result["result"] == "success" - def test_message_not_exists(self, app): + def test_message_not_exists(self, app: Flask): api = module.MessageFeedbackApi() method = unwrap(api.post) @@ -182,7 +183,7 @@ class TestMessageFeedbackApi: class TestMessageMoreLikeThisApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -221,7 +222,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(NotCompletionAppError): method(installed_app, "mid") - def test_more_like_this_disabled(self, app): + def test_more_like_this_disabled(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -243,7 +244,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(AppMoreLikeThisDisabledError): method(installed_app, "mid") - def test_message_not_exists_more_like_this(self, app): + def test_message_not_exists_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -265,7 +266,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(NotFound): method(installed_app, "mid") - def test_provider_not_init_more_like_this(self, app): + def test_provider_not_init_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -287,7 +288,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderNotInitializeError): method(installed_app, "mid") - def test_quota_exceeded_more_like_this(self, app): + def test_quota_exceeded_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -309,7 +310,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderQuotaExceededError): method(installed_app, "mid") - def test_model_not_support_more_like_this(self, app): + def test_model_not_support_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -331,7 +332,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderModelCurrentlyNotSupportError): method(installed_app, "mid") - def test_invoke_error_more_like_this(self, app): + def test_invoke_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -353,7 +354,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(CompletionRequestError): method(installed_app, "mid") - def test_unexpected_error_more_like_this(self, app): + def test_unexpected_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 76c863577a..557fded37e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock, patch +from flask import Flask + import controllers.console.explore.recommended_app as module from models.model import AppMode, IconType @@ -11,7 +13,7 @@ def unwrap(func): class TestRecommendedAppListApi: - def test_get_with_language_param(self, app): + def test_get_with_language_param(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -31,7 +33,7 @@ class TestRecommendedAppListApi: service_mock.assert_called_once_with("en-US") assert result == result_data - def test_get_fallback_to_user_language(self, app): + def test_get_fallback_to_user_language(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -51,7 +53,7 @@ class TestRecommendedAppListApi: service_mock.assert_called_once_with("fr-FR") assert result == result_data - def test_get_fallback_to_default_language(self, app): + def test_get_fallback_to_default_language(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -73,7 +75,7 @@ class TestRecommendedAppListApi: class TestRecommendedAppApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.RecommendedAppApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index bb7cdd55c4..71241890e9 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, PropertyMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.saved_message as module @@ -42,7 +43,7 @@ def payload_patch(): class TestSavedMessageListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.SavedMessageListApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 3625056af9..14f00e6295 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -88,7 +89,7 @@ def valid_parameters(): class TestTrialAppWorkflowRunApi: - def test_not_workflow_app(self, app): + def test_not_workflow_app(self, app: Flask): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -224,7 +225,7 @@ class TestTrialAppWorkflowRunApi: class TestTrialChatApi: - def test_not_chat_app(self, app): + def test_not_chat_app(self, app: Flask): api = module.TrialChatApi() method = unwrap(api.post) @@ -408,7 +409,7 @@ class TestTrialChatApi: class TestTrialCompletionApi: - def test_not_completion_app(self, app): + def test_not_completion_app(self, app: Flask): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -560,7 +561,7 @@ class TestTrialCompletionApi: class TestTrialMessageSuggestedQuestionApi: - def test_not_chat_app(self, app): + def test_not_chat_app(self, app: Flask): api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) @@ -952,7 +953,7 @@ class TestTrialAppWorkflowTaskStopApi: class TestTrialSitApi: - def test_no_site(self, app): + def test_no_site(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) app_model = MagicMock() @@ -963,7 +964,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_archived_tenant(self, app): + def test_archived_tenant(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) @@ -978,7 +979,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_success(self, app): + def test_success(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index a26d171649..8b47da25fb 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -73,7 +73,7 @@ def payload_patch(): class TestTagListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = TagListApi() method = unwrap(api.get) @@ -124,7 +124,7 @@ class TestTagListApi: assert result["name"] == "test-tag" assert result["binding_count"] == "0" - def test_post_forbidden(self, app, readonly_user, payload_patch): + def test_post_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagListApi() method = unwrap(api.post) @@ -170,7 +170,7 @@ class TestTagUpdateDeleteApi: assert status == 200 assert result["binding_count"] == "3" - def test_patch_forbidden(self, app, readonly_user, payload_patch): + def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagUpdateDeleteApi() method = unwrap(api.patch) @@ -231,7 +231,7 @@ class TestTagBindingCollectionApi: assert status == 200 assert result["result"] == "success" - def test_create_forbidden(self, app, readonly_user, payload_patch): + def test_create_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagBindingCollectionApi() method = unwrap(api.post) @@ -275,7 +275,7 @@ class TestTagBindingRemoveApi: assert status == 200 assert result["result"] == "success" - def test_remove_forbidden(self, app, readonly_user, payload_patch): + def test_remove_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagBindingRemoveApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index 5df9daa7f8..eebc6f9d60 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -82,7 +82,7 @@ def mock_file_service(mock_db): class TestFileApiGet: - def test_get_upload_config(self, app): + def test_get_upload_config(self, app: Flask): api = FileApi() get_method = unwrap(api.get) @@ -290,7 +290,7 @@ class TestFilePreviewApi: class TestFileSupportTypeApi: - def test_get_supported_types(self, app): + def test_get_supported_types(self, app: Flask): api = FileSupportTypeApi() get_method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 0b1a32581a..4b4f968c8f 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -58,7 +58,7 @@ class TestChangeEmailSend: mock_get_change_data, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") @@ -107,7 +107,7 @@ class TestChangeEmailSend: mock_get_change_data, mock_current_account, mock_db, - app, + app: Flask, ): """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" from controllers.console.auth.error import InvalidTokenError @@ -155,7 +155,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") @@ -214,7 +214,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) @@ -267,7 +267,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): """A token whose phase marker is a string but not a known transition must be rejected.""" from controllers.console.auth.error import InvalidTokenError @@ -316,7 +316,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" from controllers.console.auth.error import InvalidTokenError @@ -366,7 +366,7 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") @@ -418,7 +418,7 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" from controllers.console.auth.error import InvalidTokenError @@ -471,7 +471,7 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): """A verified token for address A must not be replayed to change to address B.""" from controllers.console.auth.error import InvalidTokenError @@ -547,7 +547,7 @@ class TestAccountServiceSendChangeEmailEmail: class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") - def test_should_normalize_feedback_email(self, mock_update, mock_db, app): + def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask): with app.test_request_context( "/account/delete/feedback", method="POST", @@ -563,7 +563,7 @@ class TestCheckEmailUnique: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask): mock_is_freeze.return_value = False mock_check_unique.return_value = True diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 811bf5b1e7..412d6a6c52 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -43,7 +43,7 @@ class TestMemberInviteEmailApi: mock_current_account, mock_invite_member, mock_get_features, - app, + app: Flask, ): mock_get_features.return_value = _build_feature_flags() mock_invite_member.return_value = "token-abc" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index bbe9d09521..064726da05 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -41,7 +42,7 @@ def unwrap(func): class TestAccountInitApi: - def test_init_success(self, app): + def test_init_success(self, app: Flask): api = AccountInitApi() method = unwrap(api.post) @@ -64,7 +65,7 @@ class TestAccountInitApi: assert resp["result"] == "success" - def test_init_already_initialized(self, app): + def test_init_already_initialized(self, app: Flask): api = AccountInitApi() method = unwrap(api.post) @@ -79,7 +80,7 @@ class TestAccountInitApi: class TestAccountProfileApi: - def test_get_profile_success(self, app): + def test_get_profile_success(self, app: Flask): api = AccountProfileApi() method = unwrap(api.get) @@ -140,7 +141,7 @@ class TestAccountUpdateApis: class TestAccountAvatarApiGet: """GET /account/avatar must not sign arbitrary upload_file IDs (IDOR).""" - def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app): + def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -172,7 +173,7 @@ class TestAccountAvatarApiGet: assert result == {"avatar_url": "https://signed/example"} sign_mock.assert_called_once_with(upload_file_id=file_id) - def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app): + def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -204,7 +205,7 @@ class TestAccountAvatarApiGet: sign_mock.assert_not_called() - def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app): + def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -236,7 +237,7 @@ class TestAccountAvatarApiGet: sign_mock.assert_not_called() - def test_get_avatar_https_pass_through_without_signing(self, app): + def test_get_avatar_https_pass_through_without_signing(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -263,7 +264,7 @@ class TestAccountAvatarApiGet: class TestAccountPasswordApi: - def test_password_success(self, app): + def test_password_success(self, app: Flask): api = AccountPasswordApi() method = unwrap(api.post) @@ -292,7 +293,7 @@ class TestAccountPasswordApi: assert result["id"] == "u1" - def test_password_wrong_current(self, app): + def test_password_wrong_current(self, app: Flask): api = AccountPasswordApi() method = unwrap(api.post) @@ -317,7 +318,7 @@ class TestAccountPasswordApi: class TestAccountIntegrateApi: - def test_get_integrates(self, app): + def test_get_integrates(self, app: Flask): api = AccountIntegrateApi() method = unwrap(api.get) @@ -336,7 +337,7 @@ class TestAccountIntegrateApi: class TestAccountDeleteApi: - def test_delete_verify_success(self, app): + def test_delete_verify_success(self, app: Flask): api = AccountDeleteVerifyApi() method = unwrap(api.get) @@ -358,7 +359,7 @@ class TestAccountDeleteApi: assert result["result"] == "success" - def test_delete_invalid_code(self, app): + def test_delete_invalid_code(self, app: Flask): api = AccountDeleteApi() method = unwrap(api.post) @@ -379,7 +380,7 @@ class TestAccountDeleteApi: class TestChangeEmailApis: - def test_check_email_code_invalid(self, app): + def test_check_email_code_invalid(self, app: Flask): api = ChangeEmailCheckApi() method = unwrap(api.post) @@ -405,7 +406,7 @@ class TestChangeEmailApis: with pytest.raises(EmailCodeError): method(api) - def test_reset_email_already_used(self, app): + def test_reset_email_already_used(self, app: Flask): api = ChangeEmailResetApi() method = unwrap(api.post) @@ -427,7 +428,7 @@ class TestChangeEmailApis: class TestCheckEmailUniqueApi: - def test_email_unique_success(self, app): + def test_email_unique_success(self, app: Flask): api = CheckEmailUnique() method = unwrap(api.post) @@ -448,7 +449,7 @@ class TestCheckEmailUniqueApi: assert result["result"] == "success" - def test_email_in_freeze(self, app): + def test_email_in_freeze(self, app: Flask): api = CheckEmailUnique() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py index b4e03f681d..eb0ca15d2e 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.error import AccountNotFound from controllers.console.workspace.agent_providers import ( @@ -16,7 +17,7 @@ def unwrap(func): class TestAgentProviderListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = AgentProviderListApi() method = unwrap(api.get) @@ -39,7 +40,7 @@ class TestAgentProviderListApi: assert result == providers - def test_get_empty_list(self, app): + def test_get_empty_list(self, app: Flask): api = AgentProviderListApi() method = unwrap(api.get) @@ -61,7 +62,7 @@ class TestAgentProviderListApi: assert result == [] - def test_get_account_not_found(self, app): + def test_get_account_not_found(self, app: Flask): api = AgentProviderListApi() method = unwrap(api.get) @@ -77,7 +78,7 @@ class TestAgentProviderListApi: class TestAgentProviderApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = AgentProviderApi() method = unwrap(api.get) @@ -101,7 +102,7 @@ class TestAgentProviderApi: assert result == provider_data - def test_get_provider_not_found(self, app): + def test_get_provider_not_found(self, app: Flask): api = AgentProviderApi() method = unwrap(api.get) @@ -124,7 +125,7 @@ class TestAgentProviderApi: assert result is None - def test_get_account_not_found(self, app): + def test_get_account_not_found(self, app: Flask): api = AgentProviderApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py index 0b3d7ef6d7..ed7b2d606f 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console import console_ns from controllers.console.workspace.endpoint import ( @@ -39,7 +40,7 @@ def patch_current_account(user_and_tenant): @pytest.mark.usefixtures("patch_current_account") class TestEndpointCollectionApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = EndpointCollectionApi() method = unwrap(api.post) @@ -57,7 +58,7 @@ class TestEndpointCollectionApi: assert result["success"] is True - def test_create_permission_denied(self, app): + def test_create_permission_denied(self, app: Flask): api = EndpointCollectionApi() method = unwrap(api.post) @@ -77,7 +78,7 @@ class TestEndpointCollectionApi: with pytest.raises(ValueError): method(api) - def test_create_validation_error(self, app): + def test_create_validation_error(self, app: Flask): api = EndpointCollectionApi() method = unwrap(api.post) @@ -96,7 +97,7 @@ class TestEndpointCollectionApi: @pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointCreateApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = DeprecatedEndpointCreateApi() method = unwrap(api.post) @@ -117,7 +118,7 @@ class TestDeprecatedEndpointCreateApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = EndpointListApi() method = unwrap(api.get) @@ -130,7 +131,7 @@ class TestEndpointListApi: assert "endpoints" in result assert len(result["endpoints"]) == 1 - def test_list_invalid_query(self, app): + def test_list_invalid_query(self, app: Flask): api = EndpointListApi() method = unwrap(api.get) @@ -143,7 +144,7 @@ class TestEndpointListApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointListForSinglePluginApi: - def test_list_for_plugin_success(self, app): + def test_list_for_plugin_success(self, app: Flask): api = EndpointListForSinglePluginApi() method = unwrap(api.get) @@ -158,7 +159,7 @@ class TestEndpointListForSinglePluginApi: assert "endpoints" in result - def test_list_for_plugin_missing_param(self, app): + def test_list_for_plugin_missing_param(self, app: Flask): api = EndpointListForSinglePluginApi() method = unwrap(api.get) @@ -171,7 +172,7 @@ class TestEndpointListForSinglePluginApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointItemApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = EndpointItemApi() method = unwrap(api.delete) @@ -187,7 +188,7 @@ class TestEndpointItemApi: assert result["success"] is True mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1") - def test_delete_service_failure(self, app): + def test_delete_service_failure(self, app: Flask): api = EndpointItemApi() method = unwrap(api.delete) @@ -199,7 +200,7 @@ class TestEndpointItemApi: assert result["success"] is False - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = EndpointItemApi() method = unwrap(api.patch) @@ -226,7 +227,7 @@ class TestEndpointItemApi: settings={"x": 1}, ) - def test_update_validation_error(self, app): + def test_update_validation_error(self, app: Flask): api = EndpointItemApi() method = unwrap(api.patch) @@ -238,7 +239,7 @@ class TestEndpointItemApi: with pytest.raises(ValueError): method(api, "e1") - def test_update_service_failure(self, app): + def test_update_service_failure(self, app: Flask): api = EndpointItemApi() method = unwrap(api.patch) @@ -258,7 +259,7 @@ class TestEndpointItemApi: @pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) @@ -272,7 +273,7 @@ class TestDeprecatedEndpointDeleteApi: assert result["success"] is True - def test_delete_invalid_payload(self, app): + def test_delete_invalid_payload(self, app: Flask): api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) @@ -282,7 +283,7 @@ class TestDeprecatedEndpointDeleteApi: with pytest.raises(ValueError): method(api) - def test_delete_service_failure(self, app): + def test_delete_service_failure(self, app: Flask): api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) @@ -299,7 +300,7 @@ class TestDeprecatedEndpointDeleteApi: @pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) @@ -317,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi: assert result["success"] is True - def test_update_validation_error(self, app): + def test_update_validation_error(self, app: Flask): api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) @@ -329,7 +330,7 @@ class TestDeprecatedEndpointUpdateApi: with pytest.raises(ValueError): method(api) - def test_update_service_failure(self, app): + def test_update_service_failure(self, app: Flask): api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) @@ -380,7 +381,7 @@ class TestEndpointRouteMetadata: @pytest.mark.usefixtures("patch_current_account") class TestEndpointEnableApi: - def test_enable_success(self, app): + def test_enable_success(self, app: Flask): api = EndpointEnableApi() method = unwrap(api.post) @@ -394,7 +395,7 @@ class TestEndpointEnableApi: assert result["success"] is True - def test_enable_invalid_payload(self, app): + def test_enable_invalid_payload(self, app: Flask): api = EndpointEnableApi() method = unwrap(api.post) @@ -404,7 +405,7 @@ class TestEndpointEnableApi: with pytest.raises(ValueError): method(api) - def test_enable_service_failure(self, app): + def test_enable_service_failure(self, app: Flask): api = EndpointEnableApi() method = unwrap(api.post) @@ -421,7 +422,7 @@ class TestEndpointEnableApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointDisableApi: - def test_disable_success(self, app): + def test_disable_success(self, app: Flask): api = EndpointDisableApi() method = unwrap(api.post) @@ -435,7 +436,7 @@ class TestEndpointDisableApi: assert result["success"] is True - def test_disable_invalid_payload(self, app): + def test_disable_invalid_payload(self, app: Flask): api = EndpointDisableApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index 718b57ba6b..0788ff603c 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import HTTPException import services @@ -34,7 +35,7 @@ def unwrap(func): class TestMemberListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = MemberListApi() method = unwrap(api.get) @@ -59,7 +60,7 @@ class TestMemberListApi: assert status == 200 assert len(result["accounts"]) == 1 - def test_get_no_tenant(self, app): + def test_get_no_tenant(self, app: Flask): api = MemberListApi() method = unwrap(api.get) @@ -74,7 +75,7 @@ class TestMemberListApi: class TestMemberInviteEmailApi: - def test_invite_success(self, app): + def test_invite_success(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -101,7 +102,7 @@ class TestMemberInviteEmailApi: assert status == 201 assert result["result"] == "success" - def test_invite_limit_exceeded(self, app): + def test_invite_limit_exceeded(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -123,7 +124,7 @@ class TestMemberInviteEmailApi: with pytest.raises(WorkspaceMembersLimitExceeded): method(api) - def test_invite_already_member(self, app): + def test_invite_already_member(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -151,7 +152,7 @@ class TestMemberInviteEmailApi: assert result["invitation_results"][0]["status"] == "success" - def test_invite_invalid_role(self, app): + def test_invite_invalid_role(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -166,7 +167,7 @@ class TestMemberInviteEmailApi: assert status == 400 assert result["code"] == "invalid-role" - def test_invite_generic_exception(self, app): + def test_invite_generic_exception(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -196,7 +197,7 @@ class TestMemberInviteEmailApi: class TestMemberCancelInviteApi: - def test_cancel_success(self, app): + def test_cancel_success(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -216,7 +217,7 @@ class TestMemberCancelInviteApi: assert status == 200 assert result["result"] == "success" - def test_cancel_not_found(self, app): + def test_cancel_not_found(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -233,7 +234,7 @@ class TestMemberCancelInviteApi: with pytest.raises(HTTPException): method(api, "x") - def test_cancel_cannot_operate_self(self, app): + def test_cancel_cannot_operate_self(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -255,7 +256,7 @@ class TestMemberCancelInviteApi: assert status == 400 - def test_cancel_no_permission(self, app): + def test_cancel_no_permission(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -277,7 +278,7 @@ class TestMemberCancelInviteApi: assert status == 403 - def test_cancel_member_not_in_tenant(self, app): + def test_cancel_member_not_in_tenant(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -301,7 +302,7 @@ class TestMemberCancelInviteApi: class TestMemberUpdateRoleApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = MemberUpdateRoleApi() method = unwrap(api.put) @@ -324,7 +325,7 @@ class TestMemberUpdateRoleApi: assert result["result"] == "success" - def test_update_invalid_role(self, app): + def test_update_invalid_role(self, app: Flask): api = MemberUpdateRoleApi() method = unwrap(api.put) @@ -335,7 +336,7 @@ class TestMemberUpdateRoleApi: assert status == 400 - def test_update_member_not_found(self, app): + def test_update_member_not_found(self, app: Flask): api = MemberUpdateRoleApi() method = unwrap(api.put) @@ -354,7 +355,7 @@ class TestMemberUpdateRoleApi: class TestDatasetOperatorMemberListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetOperatorMemberListApi() method = unwrap(api.get) @@ -381,7 +382,7 @@ class TestDatasetOperatorMemberListApi: assert status == 200 assert len(result["accounts"]) == 1 - def test_get_no_tenant(self, app): + def test_get_no_tenant(self, app: Flask): api = DatasetOperatorMemberListApi() method = unwrap(api.get) @@ -396,7 +397,7 @@ class TestDatasetOperatorMemberListApi: class TestSendOwnerTransferEmailApi: - def test_send_success(self, app): + def test_send_success(self, app: Flask): api = SendOwnerTransferEmailApi() method = unwrap(api.post) @@ -419,7 +420,7 @@ class TestSendOwnerTransferEmailApi: assert result["result"] == "success" - def test_send_ip_limit(self, app): + def test_send_ip_limit(self, app: Flask): api = SendOwnerTransferEmailApi() method = unwrap(api.post) @@ -433,7 +434,7 @@ class TestSendOwnerTransferEmailApi: with pytest.raises(EmailSendIpLimitError): method(api) - def test_send_not_owner(self, app): + def test_send_not_owner(self, app: Flask): api = SendOwnerTransferEmailApi() method = unwrap(api.post) @@ -452,7 +453,7 @@ class TestSendOwnerTransferEmailApi: class TestOwnerTransferCheckApi: - def test_check_invalid_code(self, app): + def test_check_invalid_code(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -477,7 +478,7 @@ class TestOwnerTransferCheckApi: with pytest.raises(EmailCodeError): method(api) - def test_rate_limited(self, app): + def test_rate_limited(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -498,7 +499,7 @@ class TestOwnerTransferCheckApi: with pytest.raises(OwnerTransferLimitError): method(api) - def test_invalid_token(self, app): + def test_invalid_token(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -520,7 +521,7 @@ class TestOwnerTransferCheckApi: with pytest.raises(InvalidTokenError): method(api) - def test_invalid_email(self, app): + def test_invalid_email(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -547,7 +548,7 @@ class TestOwnerTransferCheckApi: class TestOwnerTransferApi: - def test_transfer_self(self, app): + def test_transfer_self(self, app: Flask): api = OwnerTransfer() method = unwrap(api.post) @@ -564,7 +565,7 @@ class TestOwnerTransferApi: with pytest.raises(CannotTransferOwnerToSelfError): method(api, "1") - def test_invalid_token(self, app): + def test_invalid_token(self, app: Flask): api = OwnerTransfer() method = unwrap(api.post) @@ -582,7 +583,7 @@ class TestOwnerTransferApi: with pytest.raises(InvalidTokenError): method(api, "2") - def test_member_not_in_tenant(self, app): + def test_member_not_in_tenant(self, app: Flask): api = OwnerTransfer() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index 168479af1e..e836a3cc55 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -26,7 +27,7 @@ def unwrap(func): class TestModelProviderListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ModelProviderListApi() method = unwrap(api.get) @@ -47,7 +48,7 @@ class TestModelProviderListApi: class TestModelProviderCredentialApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.get) @@ -66,7 +67,7 @@ class TestModelProviderCredentialApi: assert "credentials" in result - def test_get_invalid_uuid(self, app): + def test_get_invalid_uuid(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.get) @@ -80,7 +81,7 @@ class TestModelProviderCredentialApi: with pytest.raises(ValidationError): method(api, provider="openai") - def test_post_create_success(self, app): + def test_post_create_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.post) @@ -102,7 +103,7 @@ class TestModelProviderCredentialApi: assert result["result"] == "success" assert status == 201 - def test_post_create_validation_error(self, app): + def test_post_create_validation_error(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.post) @@ -122,7 +123,7 @@ class TestModelProviderCredentialApi: with pytest.raises(ValueError): method(api, provider="openai") - def test_put_update_success(self, app): + def test_put_update_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.put) @@ -143,7 +144,7 @@ class TestModelProviderCredentialApi: assert result["result"] == "success" - def test_put_invalid_uuid(self, app): + def test_put_invalid_uuid(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.put) @@ -159,7 +160,7 @@ class TestModelProviderCredentialApi: with pytest.raises(ValidationError): method(api, provider="openai") - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.delete) @@ -183,7 +184,7 @@ class TestModelProviderCredentialApi: class TestModelProviderCredentialSwitchApi: - def test_switch_success(self, app): + def test_switch_success(self, app: Flask): api = ModelProviderCredentialSwitchApi() method = unwrap(api.post) @@ -204,7 +205,7 @@ class TestModelProviderCredentialSwitchApi: assert result["result"] == "success" - def test_switch_invalid_uuid(self, app): + def test_switch_invalid_uuid(self, app: Flask): api = ModelProviderCredentialSwitchApi() method = unwrap(api.post) @@ -222,7 +223,7 @@ class TestModelProviderCredentialSwitchApi: class TestModelProviderValidateApi: - def test_validate_success(self, app): + def test_validate_success(self, app: Flask): api = ModelProviderValidateApi() method = unwrap(api.post) @@ -243,7 +244,7 @@ class TestModelProviderValidateApi: assert result["result"] == "success" - def test_validate_failure(self, app): + def test_validate_failure(self, app: Flask): api = ModelProviderValidateApi() method = unwrap(api.post) @@ -266,7 +267,7 @@ class TestModelProviderValidateApi: class TestModelProviderIconApi: - def test_icon_success(self, app): + def test_icon_success(self, app: Flask): api = ModelProviderIconApi() with ( @@ -280,7 +281,7 @@ class TestModelProviderIconApi: assert response.mimetype == "image/png" - def test_icon_not_found(self, app): + def test_icon_not_found(self, app: Flask): api = ModelProviderIconApi() with ( @@ -295,7 +296,7 @@ class TestModelProviderIconApi: class TestPreferredProviderTypeUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = PreferredProviderTypeUpdateApi() method = unwrap(api.post) @@ -316,7 +317,7 @@ class TestPreferredProviderTypeUpdateApi: assert result["result"] == "success" - def test_invalid_enum(self, app): + def test_invalid_enum(self, app: Flask): api = PreferredProviderTypeUpdateApi() method = unwrap(api.post) @@ -334,7 +335,7 @@ class TestPreferredProviderTypeUpdateApi: class TestModelProviderPaymentCheckoutUrlApi: - def test_checkout_success(self, app): + def test_checkout_success(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() method = unwrap(api.get) @@ -359,7 +360,7 @@ class TestModelProviderPaymentCheckoutUrlApi: assert "url" in result - def test_invalid_provider(self, app): + def test_invalid_provider(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() method = unwrap(api.get) @@ -367,7 +368,7 @@ class TestModelProviderPaymentCheckoutUrlApi: with pytest.raises(ValueError): method(api, provider="openai") - def test_permission_denied(self, app): + def test_permission_denied(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index f0d32f81fb..4246e3c04c 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -72,7 +72,7 @@ class TestDefaultModelApi: assert result["result"] == "success" - def test_get_returns_empty_when_no_default(self, app): + def test_get_returns_empty_when_no_default(self, app: Flask): api = DefaultModelApi() method = unwrap(api.get) @@ -154,7 +154,7 @@ class TestModelProviderModelApi: assert status == 204 - def test_get_models_returns_empty(self, app): + def test_get_models_returns_empty(self, app: Flask): api = ModelProviderModelApi() method = unwrap(api.get) @@ -224,7 +224,7 @@ class TestModelProviderModelCredentialApi: assert status == 201 - def test_get_empty_credentials(self, app): + def test_get_empty_credentials(self, app: Flask): api = ModelProviderModelCredentialApi() method = unwrap(api.get) @@ -242,7 +242,7 @@ class TestModelProviderModelCredentialApi: assert result["credentials"] == {} - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ModelProviderModelCredentialApi() method = unwrap(api.delete) @@ -416,7 +416,7 @@ class TestParameterAndAvailableModels: assert "data" in result - def test_empty_rules(self, app): + def test_empty_rules(self, app: Flask): api = ModelProviderModelParameterRuleApi() method = unwrap(api.get) @@ -431,7 +431,7 @@ class TestParameterAndAvailableModels: assert result["data"] == [] - def test_no_models(self, app): + def test_no_models(self, app: Flask): api = ModelProviderAvailableModelApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index ce5fd1c466..d01bf7d668 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -2,6 +2,7 @@ import io from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -61,7 +62,7 @@ def tenant(): class TestPluginListLatestVersionsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginListLatestVersionsApi() method = unwrap(api.post) @@ -77,7 +78,7 @@ class TestPluginListLatestVersionsApi: assert "versions" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginListLatestVersionsApi() method = unwrap(api.post) @@ -95,7 +96,7 @@ class TestPluginListLatestVersionsApi: class TestPluginDebuggingKeyApi: - def test_debugging_key_success(self, app): + def test_debugging_key_success(self, app: Flask): api = PluginDebuggingKeyApi() method = unwrap(api.get) @@ -108,7 +109,7 @@ class TestPluginDebuggingKeyApi: assert result["key"] == "k" - def test_debugging_key_error(self, app): + def test_debugging_key_error(self, app: Flask): api = PluginDebuggingKeyApi() method = unwrap(api.get) @@ -125,7 +126,7 @@ class TestPluginDebuggingKeyApi: class TestPluginListApi: - def test_plugin_list(self, app): + def test_plugin_list(self, app: Flask): api = PluginListApi() method = unwrap(api.get) @@ -142,7 +143,7 @@ class TestPluginListApi: class TestPluginIconApi: - def test_plugin_icon(self, app): + def test_plugin_icon(self, app: Flask): api = PluginIconApi() method = unwrap(api.get) @@ -156,7 +157,7 @@ class TestPluginIconApi: class TestPluginAssetApi: - def test_plugin_asset(self, app): + def test_plugin_asset(self, app: Flask): api = PluginAssetApi() method = unwrap(api.get) @@ -171,7 +172,7 @@ class TestPluginAssetApi: class TestPluginUploadFromPkgApi: - def test_upload_pkg_success(self, app): + def test_upload_pkg_success(self, app: Flask): api = PluginUploadFromPkgApi() method = unwrap(api.post) @@ -188,7 +189,7 @@ class TestPluginUploadFromPkgApi: assert result["ok"] is True - def test_upload_pkg_too_large(self, app): + def test_upload_pkg_too_large(self, app: Flask): api = PluginUploadFromPkgApi() method = unwrap(api.post) @@ -210,7 +211,7 @@ class TestPluginUploadFromPkgApi: class TestPluginInstallFromPkgApi: - def test_install_from_pkg(self, app): + def test_install_from_pkg(self, app: Flask): api = PluginInstallFromPkgApi() method = unwrap(api.post) @@ -229,7 +230,7 @@ class TestPluginInstallFromPkgApi: class TestPluginUninstallApi: - def test_uninstall(self, app): + def test_uninstall(self, app: Flask): api = PluginUninstallApi() method = unwrap(api.post) @@ -246,7 +247,7 @@ class TestPluginUninstallApi: class TestPluginChangePermissionApi: - def test_change_permission_forbidden(self, app): + def test_change_permission_forbidden(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) @@ -264,7 +265,7 @@ class TestPluginChangePermissionApi: with pytest.raises(Forbidden): method(api) - def test_change_permission_success(self, app): + def test_change_permission_success(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) @@ -286,7 +287,7 @@ class TestPluginChangePermissionApi: class TestPluginFetchPermissionApi: - def test_fetch_permission_default(self, app): + def test_fetch_permission_default(self, app: Flask): api = PluginFetchPermissionApi() method = unwrap(api.get) @@ -319,7 +320,7 @@ class TestPluginFetchDynamicSelectOptionsApi: class TestPluginReadmeApi: - def test_fetch_readme(self, app): + def test_fetch_readme(self, app: Flask): api = PluginReadmeApi() method = unwrap(api.get) @@ -334,7 +335,7 @@ class TestPluginReadmeApi: class TestPluginListInstallationsFromIdsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -352,7 +353,7 @@ class TestPluginListInstallationsFromIdsApi: assert "plugins" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -371,7 +372,7 @@ class TestPluginListInstallationsFromIdsApi: class TestPluginUploadFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -388,7 +389,7 @@ class TestPluginUploadFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -407,7 +408,7 @@ class TestPluginUploadFromGithubApi: class TestPluginUploadFromBundleApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUploadFromBundleApi() method = unwrap(api.post) @@ -430,7 +431,7 @@ class TestPluginUploadFromBundleApi: assert result["ok"] is True - def test_too_large(self, app): + def test_too_large(self, app: Flask): api = PluginUploadFromBundleApi() method = unwrap(api.post) @@ -458,7 +459,7 @@ class TestPluginUploadFromBundleApi: class TestPluginInstallFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginInstallFromGithubApi() method = unwrap(api.post) @@ -478,7 +479,7 @@ class TestPluginInstallFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginInstallFromGithubApi() method = unwrap(api.post) @@ -502,7 +503,7 @@ class TestPluginInstallFromGithubApi: class TestPluginInstallFromMarketplaceApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginInstallFromMarketplaceApi() method = unwrap(api.post) @@ -520,7 +521,7 @@ class TestPluginInstallFromMarketplaceApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginInstallFromMarketplaceApi() method = unwrap(api.post) @@ -539,7 +540,7 @@ class TestPluginInstallFromMarketplaceApi: class TestPluginFetchMarketplacePkgApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchMarketplacePkgApi() method = unwrap(api.get) @@ -552,7 +553,7 @@ class TestPluginFetchMarketplacePkgApi: assert "manifest" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchMarketplacePkgApi() method = unwrap(api.get) @@ -569,7 +570,7 @@ class TestPluginFetchMarketplacePkgApi: class TestPluginFetchManifestApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchManifestApi() method = unwrap(api.get) @@ -585,7 +586,7 @@ class TestPluginFetchManifestApi: assert "manifest" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchManifestApi() method = unwrap(api.get) @@ -602,7 +603,7 @@ class TestPluginFetchManifestApi: class TestPluginFetchInstallTasksApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchInstallTasksApi() method = unwrap(api.get) @@ -615,7 +616,7 @@ class TestPluginFetchInstallTasksApi: assert "tasks" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchInstallTasksApi() method = unwrap(api.get) @@ -632,7 +633,7 @@ class TestPluginFetchInstallTasksApi: class TestPluginFetchInstallTaskApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchInstallTaskApi() method = unwrap(api.get) @@ -645,7 +646,7 @@ class TestPluginFetchInstallTaskApi: assert "task" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchInstallTaskApi() method = unwrap(api.get) @@ -662,7 +663,7 @@ class TestPluginFetchInstallTaskApi: class TestPluginDeleteInstallTaskApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteInstallTaskApi() method = unwrap(api.post) @@ -675,7 +676,7 @@ class TestPluginDeleteInstallTaskApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteInstallTaskApi() method = unwrap(api.post) @@ -692,7 +693,7 @@ class TestPluginDeleteInstallTaskApi: class TestPluginDeleteAllInstallTaskItemsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteAllInstallTaskItemsApi() method = unwrap(api.post) @@ -707,7 +708,7 @@ class TestPluginDeleteAllInstallTaskItemsApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteAllInstallTaskItemsApi() method = unwrap(api.post) @@ -724,7 +725,7 @@ class TestPluginDeleteAllInstallTaskItemsApi: class TestPluginDeleteInstallTaskItemApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteInstallTaskItemApi() method = unwrap(api.post) @@ -737,7 +738,7 @@ class TestPluginDeleteInstallTaskItemApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteInstallTaskItemApi() method = unwrap(api.post) @@ -754,7 +755,7 @@ class TestPluginDeleteInstallTaskItemApi: class TestPluginUpgradeFromMarketplaceApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUpgradeFromMarketplaceApi() method = unwrap(api.post) @@ -775,7 +776,7 @@ class TestPluginUpgradeFromMarketplaceApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUpgradeFromMarketplaceApi() method = unwrap(api.post) @@ -797,7 +798,7 @@ class TestPluginUpgradeFromMarketplaceApi: class TestPluginUpgradeFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUpgradeFromGithubApi() method = unwrap(api.post) @@ -821,7 +822,7 @@ class TestPluginUpgradeFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUpgradeFromGithubApi() method = unwrap(api.post) @@ -846,7 +847,7 @@ class TestPluginUpgradeFromGithubApi: class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) @@ -873,7 +874,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: assert result["options"] == [1] - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) @@ -901,7 +902,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: class TestPluginChangePreferencesApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginChangePreferencesApi() method = unwrap(api.post) @@ -931,7 +932,7 @@ class TestPluginChangePreferencesApi: assert result["success"] is True - def test_permission_fail(self, app): + def test_permission_fail(self, app: Flask): api = PluginChangePreferencesApi() method = unwrap(api.post) @@ -962,7 +963,7 @@ class TestPluginChangePreferencesApi: class TestPluginFetchPreferencesApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchPreferencesApi() method = unwrap(api.get) @@ -996,7 +997,7 @@ class TestPluginFetchPreferencesApi: class TestPluginAutoUpgradeExcludePluginApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginAutoUpgradeExcludePluginApi() method = unwrap(api.post) @@ -1011,7 +1012,7 @@ class TestPluginAutoUpgradeExcludePluginApi: assert result["success"] is True - def test_fail(self, app): + def test_fail(self, app: Flask): api = PluginAutoUpgradeExcludePluginApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index e82a29f045..a52518c2d2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -2,6 +2,7 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Unauthorized @@ -37,7 +38,7 @@ def unwrap(func): class TestTenantListApi: - def test_get_success_saas_path(self, app): + def test_get_success_saas_path(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -85,7 +86,7 @@ class TestTenantListApi: get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) get_features_mock.assert_not_called() - def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app): + def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app: Flask): """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used. billing.enabled is mocked False to prove the endpoint does not gate on it for this path @@ -140,7 +141,7 @@ class TestTenantListApi: get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) get_features_mock.assert_called_once_with("t2") - def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app): + def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask): """Test fallback to FeatureService when bulk billing returns empty result. BillingService.get_plan_bulk catches exceptions internally and returns empty dict, @@ -197,7 +198,7 @@ class TestTenantListApi: assert get_features_mock.call_count == 2 logger_warning_mock.assert_called_once() - def test_get_billing_disabled_community_path(self, app): + def test_get_billing_disabled_community_path(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -236,7 +237,7 @@ class TestTenantListApi: assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX get_features_mock.assert_called_once_with("t1") - def test_get_enterprise_only_skips_feature_service(self, app): + def test_get_enterprise_only_skips_feature_service(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -276,7 +277,7 @@ class TestTenantListApi: assert result["workspaces"][1]["current"] is True get_features_mock.assert_not_called() - def test_get_enterprise_only_with_empty_tenants(self, app): + def test_get_enterprise_only_with_empty_tenants(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -302,7 +303,7 @@ class TestTenantListApi: class TestWorkspaceListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = WorkspaceListApi() method = unwrap(api.get) @@ -324,7 +325,7 @@ class TestWorkspaceListApi: assert result["total"] == 1 assert result["has_more"] is False - def test_get_has_next_true(self, app): + def test_get_has_next_true(self, app: Flask): api = WorkspaceListApi() method = unwrap(api.get) @@ -355,7 +356,7 @@ class TestWorkspaceListApi: class TestTenantApi: - def test_post_active_tenant(self, app): + def test_post_active_tenant(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -375,7 +376,7 @@ class TestTenantApi: assert status == 200 assert result["id"] == "t1" - def test_post_archived_with_switch(self, app): + def test_post_archived_with_switch(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -397,7 +398,7 @@ class TestTenantApi: assert result["id"] == "new" - def test_post_archived_no_tenant(self, app): + def test_post_archived_no_tenant(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -411,7 +412,7 @@ class TestTenantApi: with pytest.raises(Unauthorized): method(api) - def test_post_info_path(self, app): + def test_post_info_path(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -454,7 +455,7 @@ class TestTenantInfoResponse: class TestSwitchWorkspaceApi: - def test_switch_success(self, app): + def test_switch_success(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -477,7 +478,7 @@ class TestSwitchWorkspaceApi: assert result["result"] == "success" - def test_switch_not_linked(self, app): + def test_switch_not_linked(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -493,7 +494,7 @@ class TestSwitchWorkspaceApi: with pytest.raises(AccountNotLinkTenantError): method(api) - def test_switch_tenant_not_found(self, app): + def test_switch_tenant_not_found(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -515,7 +516,7 @@ class TestSwitchWorkspaceApi: class TestCustomConfigWorkspaceApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CustomConfigWorkspaceApi() method = unwrap(api.post) @@ -538,7 +539,7 @@ class TestCustomConfigWorkspaceApi: assert result["result"] == "success" - def test_logo_fallback(self, app): + def test_logo_fallback(self, app: Flask): api = CustomConfigWorkspaceApi() method = unwrap(api.post) @@ -569,7 +570,7 @@ class TestCustomConfigWorkspaceApi: class TestWebappLogoWorkspaceApi: - def test_no_file(self, app): + def test_no_file(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -582,7 +583,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(NoFileUploadedError): method(api) - def test_too_many_files(self, app): + def test_too_many_files(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -601,7 +602,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(TooManyFilesError): method(api) - def test_invalid_extension(self, app): + def test_invalid_extension(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -616,7 +617,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(UnsupportedFileTypeError): method(api) - def test_upload_success(self, app): + def test_upload_success(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -648,7 +649,7 @@ class TestWebappLogoWorkspaceApi: assert status == 201 assert result["id"] == "file1" - def test_filename_missing(self, app): + def test_filename_missing(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -672,7 +673,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(FilenameNotExistsError): method(api) - def test_file_too_large(self, app): + def test_file_too_large(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -701,7 +702,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(FileTooLargeError): method(api) - def test_service_unsupported_file(self, app): + def test_service_unsupported_file(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -732,7 +733,7 @@ class TestWebappLogoWorkspaceApi: class TestWorkspaceInfoApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = WorkspaceInfoApi() method = unwrap(api.post) @@ -756,7 +757,7 @@ class TestWorkspaceInfoApi: assert result["result"] == "success" - def test_no_current_tenant(self, app): + def test_no_current_tenant(self, app: Flask): api = WorkspaceInfoApi() method = unwrap(api.post) @@ -774,7 +775,7 @@ class TestWorkspaceInfoApi: class TestWorkspacePermissionApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = WorkspacePermissionApi() method = unwrap(api.get) @@ -799,7 +800,7 @@ class TestWorkspacePermissionApi: assert status == 200 assert result["workspace_id"] == "t1" - def test_no_current_tenant(self, app): + def test_no_current_tenant(self, app: Flask): api = WorkspacePermissionApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f5d93b5ac3..ae0edcf382 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -41,7 +41,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_for_chat_app( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving parameters for a chat app.""" # Arrange @@ -91,7 +91,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_for_workflow_app( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving parameters for a workflow app.""" # Arrange @@ -136,7 +136,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_raises_error_when_chat_config_missing( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test that AppUnavailableError is raised when chat app has no config.""" # Arrange @@ -174,7 +174,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_raises_error_when_workflow_missing( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test that AppUnavailableError is raised when workflow app has no workflow.""" # Arrange @@ -234,7 +234,14 @@ class TestAppMetaApi: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.app.app.AppService") def test_get_app_meta( - self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, + mock_app_service, + mock_db, + mock_validate_token, + mock_current_app, + mock_user_logged_in, + app: Flask, + mock_app_model, ): """Test retrieving app metadata via AppService.""" # Arrange @@ -310,7 +317,7 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_app_info( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving basic app information.""" mock_current_app.login_manager = Mock() @@ -402,7 +409,9 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.current_app") @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") - def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app): + def test_get_app_info_with_no_tags( + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask + ): """Test retrieving app info when app has no tags.""" # Arrange mock_current_app.login_manager = Mock() @@ -453,7 +462,7 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_app_info_returns_correct_mode( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, app_mode ): """Test that all app modes are correctly returned.""" # Arrange diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index c16ebad739..4741481ef6 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,6 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -190,7 +191,7 @@ class TestAudioServiceMockedBehavior: class TestAudioApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) api = AudioApi() handler = _unwrap(api.post) @@ -216,7 +217,7 @@ class TestAudioApi: (InvokeError("invoke"), CompletionRequestError), ], ) - def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: + def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) api = AudioApi() handler = _unwrap(api.post) @@ -227,7 +228,7 @@ class TestAudioApi: with pytest.raises(expected): handler(api, app_model=app_model, end_user=end_user) - def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_unhandled_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")) ) @@ -242,7 +243,7 @@ class TestAudioApi: class TestTextApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) api = TextApi() @@ -259,7 +260,7 @@ class TestTextApi: assert response == {"audio": "ok"} - def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()) ) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 3364c07e62..259741937f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -295,7 +296,7 @@ class TestCompletionControllerLogic: @patch("controllers.service_api.app.completion.service_api_ns") @patch("controllers.service_api.app.completion.AppGenerateService") - def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app): + def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask): """Test CompletionApi.post success path.""" from controllers.service_api.app.completion import CompletionApi @@ -320,7 +321,7 @@ class TestCompletionControllerLogic: mock_generate_service.generate.assert_called_once() @patch("controllers.service_api.app.completion.service_api_ns") - def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app): + def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask): """Test CompletionApi.post with wrong app mode.""" from controllers.service_api.app.completion import CompletionApi @@ -334,7 +335,7 @@ class TestCompletionControllerLogic: @patch("controllers.service_api.app.completion.service_api_ns") @patch("controllers.service_api.app.completion.AppGenerateService") - def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app): + def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask): """Test ChatApi.post success path.""" from controllers.service_api.app.completion import ChatApi @@ -355,7 +356,7 @@ class TestCompletionControllerLogic: assert response == {"text": "compacted"} @patch("controllers.service_api.app.completion.service_api_ns") - def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app): + def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask): """Test ChatApi.post with wrong app mode.""" from controllers.service_api.app.completion import ChatApi @@ -368,7 +369,7 @@ class TestCompletionControllerLogic: ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user) @patch("controllers.service_api.app.completion.AppTaskService") - def test_completion_stop_api_success(self, mock_task_service, app): + def test_completion_stop_api_success(self, mock_task_service, app: Flask): """Test CompletionStopApi.post success.""" from controllers.service_api.app.completion import CompletionStopApi @@ -385,7 +386,7 @@ class TestCompletionControllerLogic: mock_task_service.stop_task.assert_called_once() @patch("controllers.service_api.app.completion.AppTaskService") - def test_chat_stop_api_success(self, mock_task_service, app): + def test_chat_stop_api_success(self, mock_task_service, app: Flask): """Test ChatStopApi.post success.""" from controllers.service_api.app.completion import ChatStopApi diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 4fb8ecf784..6dc8f54d42 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -20,6 +20,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, NotFound import services @@ -504,7 +505,7 @@ class TestConversationApiController: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user) - def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_list_last_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: class _BeginStub: def __enter__(self): return SimpleNamespace() @@ -552,7 +553,7 @@ class TestConversationDetailApiController: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") - def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "delete", @@ -570,7 +571,7 @@ class TestConversationDetailApiController: class TestConversationRenameApiController: - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "rename", @@ -602,7 +603,7 @@ class TestConversationVariablesApiController: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "get_conversational_variable", @@ -621,7 +622,7 @@ class TestConversationVariablesApiController: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") - def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) monkeypatch.setattr( ConversationService, @@ -661,7 +662,7 @@ class TestConversationVariablesApiController: class TestConversationVariableDetailApiController: - def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_update_type_mismatch(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "update_conversation_variable", @@ -687,7 +688,7 @@ class TestConversationVariableDetailApiController: variable_id="00000000-0000-0000-0000-000000000002", ) - def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_update_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "update_conversation_variable", @@ -713,7 +714,7 @@ class TestConversationVariableDetailApiController: variable_id="00000000-0000-0000-0000-000000000002", ) - def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_update_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) monkeypatch.setattr( ConversationService, diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file.py b/api/tests/unit_tests/controllers/service_api/app/test_file.py index 7060bd79df..2615c3edac 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file.py @@ -16,6 +16,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from controllers.common.errors import ( FilenameNotExistsError, @@ -282,7 +283,7 @@ class TestFileApiPost: assert status == 201 mock_file_svc_cls.return_value.upload_file.assert_called_once() - def test_upload_no_file(self, app, mock_app_model, mock_end_user): + def test_upload_no_file(self, app: Flask, mock_app_model, mock_end_user): """Test NoFileUploadedError when no file in request.""" from controllers.service_api.app.file import FileApi @@ -296,7 +297,7 @@ class TestFileApiPost: with pytest.raises(NoFileUploadedError): _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) - def test_upload_too_many_files(self, app, mock_app_model, mock_end_user): + def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user): """Test TooManyFilesError when multiple files uploaded.""" from io import BytesIO @@ -317,7 +318,7 @@ class TestFileApiPost: with pytest.raises(TooManyFilesError): _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) - def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user): + def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user): """Test UnsupportedFileTypeError when file has no mimetype.""" from io import BytesIO diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py index 846d5368f3..510d4a9470 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py @@ -11,6 +11,7 @@ from types import SimpleNamespace from unittest.mock import ANY, MagicMock, Mock import pytest +from flask import Flask import services.app_generate_service as ags_module from controllers.service_api.app.workflow_events import WorkflowEventsApi @@ -281,7 +282,7 @@ class TestHitlServiceApi: workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open( - self, app, monkeypatch: pytest.MonkeyPatch + self, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: workflow_run = SimpleNamespace( id="run-1", diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index c2b8aed1ae..2bc9771862 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.service_api.app.error import NotChatAppError @@ -390,7 +391,7 @@ class TestMessageListApi: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user) - def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_conversation_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "pagination_by_first_id", @@ -409,7 +410,7 @@ class TestMessageListApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user) - def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_first_message_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "pagination_by_first_id", @@ -430,7 +431,7 @@ class TestMessageListApi: class TestMessageFeedbackApi: - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "create_feedback", @@ -452,7 +453,7 @@ class TestMessageFeedbackApi: class TestAppGetFeedbacksApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"]) api = AppGetFeedbacksApi() @@ -476,7 +477,7 @@ class TestMessageSuggestedApi: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", @@ -492,7 +493,7 @@ class TestMessageSuggestedApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_disabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", @@ -508,7 +509,7 @@ class TestMessageSuggestedApi: with pytest.raises(BadRequest): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_internal_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", @@ -524,7 +525,7 @@ class TestMessageSuggestedApi: with pytest.raises(InternalServerError): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index da09ec13ce..7115ea1e12 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -20,6 +20,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -366,7 +367,7 @@ class TestWorkflowRunRepository: class TestWorkflowRunDetailApi: - def test_not_workflow_app(self, app) -> None: + def test_not_workflow_app(self, app: Flask) -> None: api = WorkflowRunDetailApi() handler = _unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) @@ -397,7 +398,7 @@ class TestWorkflowRunDetailApi: class TestWorkflowRunApi: - def test_not_workflow_app(self, app) -> None: + def test_not_workflow_app(self, app: Flask) -> None: api = WorkflowRunApi() handler = _unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) @@ -407,7 +408,7 @@ class TestWorkflowRunApi: with pytest.raises(NotWorkflowAppError): handler(api, app_model=app_model, end_user=end_user) - def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_rate_limit(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AppGenerateService, "generate", @@ -425,7 +426,7 @@ class TestWorkflowRunApi: class TestWorkflowRunByIdApi: - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AppGenerateService, "generate", @@ -441,7 +442,7 @@ class TestWorkflowRunByIdApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, workflow_id="w1") - def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_draft_workflow(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AppGenerateService, "generate", @@ -459,7 +460,7 @@ class TestWorkflowRunByIdApi: class TestWorkflowTaskStopApi: - def test_wrong_mode(self, app) -> None: + def test_wrong_mode(self, app: Flask) -> None: api = WorkflowTaskStopApi() handler = _unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) @@ -469,7 +470,7 @@ class TestWorkflowTaskStopApi: with pytest.raises(NotWorkflowAppError): handler(api, app_model=app_model, end_user=end_user, task_id="t1") - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: stop_mock = Mock() send_mock = Mock() monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock) @@ -489,7 +490,7 @@ class TestWorkflowTaskStopApi: class TestWorkflowAppLogApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: class _BeginStub: def __enter__(self): return SimpleNamespace() @@ -557,7 +558,7 @@ class TestWorkflowRunDetailApiGet: self, mock_db, mock_repo_factory, - app, + app: Flask, mock_workflow_app, ): """Test successful workflow run detail retrieval.""" @@ -579,7 +580,7 @@ class TestWorkflowRunDetailApiGet: assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") - def test_get_workflow_run_wrong_app_mode(self, mock_db, app): + def test_get_workflow_run_wrong_app_mode(self, mock_db, app: Flask): """Test NotWorkflowAppError when app mode is not workflow or advanced_chat.""" from controllers.service_api.app.workflow import WorkflowRunDetailApi @@ -604,7 +605,7 @@ class TestWorkflowTaskStopApiPost: self, mock_queue_mgr, mock_graph_mgr, - app, + app: Flask, mock_workflow_app, ): """Test successful workflow task stop.""" @@ -624,7 +625,7 @@ class TestWorkflowTaskStopApiPost: mock_graph_mgr.assert_called_once() mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1") - def test_stop_workflow_task_wrong_app_mode(self, app): + def test_stop_workflow_task_wrong_app_mode(self, app: Flask): """Test NotWorkflowAppError when app mode is not workflow.""" from controllers.service_api.app.workflow import WorkflowTaskStopApi @@ -649,7 +650,7 @@ class TestWorkflowAppLogApiGet: self, mock_db, mock_wf_svc_cls, - app, + app: Flask, mock_workflow_app, ): """Test successful workflow log retrieval.""" diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py index f45a7f9632..b3edc2ecd8 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -41,7 +42,7 @@ class TestWorkflowEventsApi: with pytest.raises(NotWorkflowAppError): handler(api, app_model=app_model, end_user=end_user, task_id="run-1") - def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: _mock_repo_for_run(monkeypatch, workflow_run=None) api = WorkflowEventsApi() handler = _unwrap(api.get) @@ -52,7 +53,7 @@ class TestWorkflowEventsApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, task_id="run-1") - def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_workflow_run_permission_denied(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", @@ -70,7 +71,7 @@ class TestWorkflowEventsApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, task_id="run-1") - def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_finished_run_returns_sse(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", @@ -103,7 +104,7 @@ class TestWorkflowEventsApi: assert payload["task_id"] == "run-1" assert payload["event"] == "workflow_finished" - def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_running_run_streams_events(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", @@ -135,7 +136,7 @@ class TestWorkflowEventsApi: ) workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) - def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_running_run_with_snapshot(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py index f33c482d04..362af883ed 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py @@ -23,6 +23,7 @@ from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden, NotFound @@ -373,7 +374,7 @@ class TestDatasourcePluginsApiGet: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") - def test_get_plugins_success(self, mock_svc_cls, mock_db, app): + def test_get_plugins_success(self, mock_svc_cls, mock_db, app: Flask): """Test successful retrieval of datasource plugins.""" tenant_id = str(uuid.uuid4()) dataset_id = str(uuid.uuid4()) @@ -396,7 +397,7 @@ class TestDatasourcePluginsApiGet: ) @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_get_plugins_not_found(self, mock_db, app): + def test_get_plugins_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" mock_db.session.scalar.return_value = None @@ -407,7 +408,7 @@ class TestDatasourcePluginsApiGet: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") - def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app): + def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app: Flask): """Test empty plugin list.""" mock_db.session.scalar.return_value = Mock() mock_svc_instance = Mock() @@ -439,7 +440,7 @@ class TestDatasourceNodeRunApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app): + def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app: Flask): """Test successful datasource node run.""" tenant_id = str(uuid.uuid4()) dataset_id = str(uuid.uuid4()) @@ -473,7 +474,7 @@ class TestDatasourceNodeRunApiPost: mock_svc_instance.run_datasource_workflow_node.assert_called_once() @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_post_not_found(self, mock_db, app): + def test_post_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" mock_db.session.scalar.return_value = None @@ -488,7 +489,7 @@ class TestDatasourceNodeRunApiPost: ) @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app): + def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app: Flask): """Test AssertionError when current_user is not an Account instance.""" mock_db.session.scalar.return_value = Mock() mock_ns.payload = { @@ -549,7 +550,7 @@ class TestPipelineRunApiPost: mock_gen_svc.generate.assert_called_once() @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_post_not_found(self, mock_db, app): + def test_post_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" mock_db.session.scalar.return_value = None @@ -561,7 +562,7 @@ class TestPipelineRunApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app): + def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app: Flask): """Test Forbidden when current_user is not an Account.""" mock_db.session.scalar.return_value = Mock() mock_ns.payload = { @@ -585,7 +586,7 @@ class TestFileUploadApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app): + def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app: Flask): """Test successful file upload.""" mock_current_user.__bool__ = Mock(return_value=True) @@ -621,7 +622,7 @@ class TestFileUploadApiPost: assert response["name"] == "doc.pdf" assert response["extension"] == "pdf" - def test_upload_no_file(self, app): + def test_upload_no_file(self, app: Flask): """Test error when no file is uploaded.""" with app.test_request_context( "/datasets/pipeline/file-upload", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index e9c3e6d376..fe8fc02548 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -18,6 +18,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.dataset.segment import ( @@ -782,7 +783,7 @@ class TestSegmentApiGet: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -893,7 +894,7 @@ class TestSegmentApiPost: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -946,7 +947,7 @@ class TestSegmentApiPost: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -989,7 +990,7 @@ class TestSegmentApiPost: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1041,7 +1042,7 @@ class TestDatasetSegmentApiDelete: mock_doc_svc, mock_dataset_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1086,7 +1087,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_doc_svc, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1282,7 +1283,7 @@ class TestDatasetSegmentApiUpdate: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1322,7 +1323,7 @@ class TestDatasetSegmentApiUpdate: mock_dataset_svc, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1374,7 +1375,7 @@ class TestDatasetSegmentApiGetSingle: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1421,7 +1422,7 @@ class TestDatasetSegmentApiGetSingle: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1460,7 +1461,7 @@ class TestDatasetSegmentApiGetSingle: self, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1491,7 +1492,7 @@ class TestDatasetSegmentApiGetSingle: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1526,7 +1527,7 @@ class TestDatasetSegmentApiGetSingle: mock_dataset_svc, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1570,7 +1571,7 @@ class TestChildChunkApiGet: mock_doc_svc, mock_seg_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1609,7 +1610,7 @@ class TestChildChunkApiGet: self, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1638,7 +1639,7 @@ class TestChildChunkApiGet: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1670,7 +1671,7 @@ class TestChildChunkApiGet: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1729,7 +1730,7 @@ class TestChildChunkApiPost: mock_doc_svc, mock_seg_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1771,7 +1772,7 @@ class TestChildChunkApiPost: mock_feature_svc, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1809,7 +1810,7 @@ class TestChildChunkApiPost: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1863,7 +1864,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1913,7 +1914,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1954,7 +1955,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1994,7 +1995,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index b93a1cf14b..b7e24f9201 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -19,6 +19,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.dataset.metadata import ( @@ -76,7 +77,7 @@ class TestDatasetMetadataCreatePost: mock_dataset_svc, mock_meta_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -106,7 +107,7 @@ class TestDatasetMetadataCreatePost: def test_create_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -136,7 +137,7 @@ class TestDatasetMetadataCreateGet: self, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -160,7 +161,7 @@ class TestDatasetMetadataCreateGet: def test_get_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -201,7 +202,7 @@ class TestDatasetMetadataServiceApiPatch: mock_dataset_svc, mock_meta_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -232,7 +233,7 @@ class TestDatasetMetadataServiceApiPatch: def test_update_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -273,7 +274,7 @@ class TestDatasetMetadataServiceApiDelete: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -302,7 +303,7 @@ class TestDatasetMetadataServiceApiDelete: def test_delete_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -336,7 +337,7 @@ class TestDatasetMetadataBuiltInFieldGet: def test_get_built_in_fields_success( self, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -382,7 +383,7 @@ class TestDatasetMetadataBuiltInFieldAction: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -414,7 +415,7 @@ class TestDatasetMetadataBuiltInFieldAction: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -441,7 +442,7 @@ class TestDatasetMetadataBuiltInFieldAction: def test_action_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -485,7 +486,7 @@ class TestDocumentMetadataEditPost: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -513,7 +514,7 @@ class TestDocumentMetadataEditPost: def test_update_documents_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): diff --git a/api/tests/unit_tests/controllers/service_api/test_index.py b/api/tests/unit_tests/controllers/service_api/test_index.py index c560a3c698..8441118181 100644 --- a/api/tests/unit_tests/controllers/service_api/test_index.py +++ b/api/tests/unit_tests/controllers/service_api/test_index.py @@ -5,6 +5,7 @@ Unit tests for Service API Index endpoint from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.service_api.index import IndexApi @@ -13,7 +14,7 @@ class TestIndexApi: """Test suite for IndexApi resource.""" @patch("controllers.service_api.index.dify_config", autospec=True) - def test_get_returns_api_info(self, mock_config, app): + def test_get_returns_api_info(self, mock_config, app: Flask): """Test that GET returns API metadata with correct structure.""" # Arrange mock_config.project.version = "1.0.0-test" @@ -32,7 +33,7 @@ class TestIndexApi: assert response["api_version"] == "v1" assert response["server_version"] == "1.0.0-test" - def test_get_response_has_required_fields(self, app): + def test_get_response_has_required_fields(self, app: Flask): """Test that response contains all required fields.""" # Arrange mock_config = MagicMock() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 6dfbdcf98e..30d7b92913 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -39,7 +39,7 @@ class TestValidateAndGetApiToken: app.config["TESTING"] = True return app - def test_missing_authorization_header(self, app): + def test_missing_authorization_header(self, app: Flask): """Test that Unauthorized is raised when Authorization header is missing.""" # Arrange with app.test_request_context("/", method="GET"): @@ -50,7 +50,7 @@ class TestValidateAndGetApiToken: validate_and_get_api_token("app") assert "Authorization header must be provided" in str(exc_info.value) - def test_invalid_auth_scheme(self, app): + def test_invalid_auth_scheme(self, app: Flask): """Test that Unauthorized is raised when auth scheme is not Bearer.""" # Arrange with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}): @@ -62,7 +62,7 @@ class TestValidateAndGetApiToken: @patch("controllers.service_api.wraps.record_token_usage") @patch("controllers.service_api.wraps.ApiTokenCache") @patch("controllers.service_api.wraps.fetch_token_with_single_flight") - def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app): + def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask): """Test that valid token returns the ApiToken object.""" # Arrange mock_api_token = Mock(spec=ApiToken) @@ -84,7 +84,7 @@ class TestValidateAndGetApiToken: @patch("controllers.service_api.wraps.record_token_usage") @patch("controllers.service_api.wraps.ApiTokenCache") @patch("controllers.service_api.wraps.fetch_token_with_single_flight") - def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app): + def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask): """Test that invalid token raises Unauthorized.""" # Arrange from werkzeug.exceptions import Unauthorized @@ -161,7 +161,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app no longer exists.""" # Arrange mock_api_token = Mock() @@ -182,7 +182,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app status is abnormal.""" # Arrange mock_api_token = Mock() @@ -205,7 +205,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app API is disabled.""" # Arrange mock_api_token = Mock() @@ -240,7 +240,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app): + def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app: Flask): """Test that request is allowed when under resource limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -264,7 +264,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app): + def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app: Flask): """Test that Forbidden is raised when at resource limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -287,7 +287,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app): + def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app: Flask): """Test that request is allowed when billing is disabled.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -320,7 +320,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app): + def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask): """Test that add_segment is rejected in SANDBOX plan.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -342,7 +342,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app): + def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask): """Test that non-add_segment operations are allowed in SANDBOX.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -376,7 +376,7 @@ class TestCloudEditionBillingRateLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit") - def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app): + def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app: Flask): """Test that request is allowed when within rate limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -406,7 +406,7 @@ class TestCloudEditionBillingRateLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit") @patch("controllers.service_api.wraps.db") - def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app): + def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app: Flask): """Test that Forbidden is raised when over rate limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -445,7 +445,7 @@ class TestValidateDatasetToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.current_app") - def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app): + def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask): """Test that valid dataset token allows access.""" # Arrange # Use standard Mock for login_manager @@ -487,7 +487,7 @@ class TestValidateDatasetToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app): + def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app: Flask): """Test that NotFound is raised when dataset doesn't exist.""" # Arrange mock_api_token = Mock() diff --git a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py index 86b461cf04..c1c1291281 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py +++ b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py @@ -13,7 +13,7 @@ class TestExternalDataFetch: app = Flask(__name__) return app - def test_fetch_success(self, app): + def test_fetch_success(self, app: Flask): with app.app_context(): fetcher = ExternalDataFetch() @@ -79,7 +79,7 @@ class TestExternalDataFetch: assert result_inputs == inputs assert result_inputs is not inputs # Should be a copy - def test_fetch_with_none_variable(self, app): + def test_fetch_with_none_variable(self, app: Flask): with app.app_context(): fetcher = ExternalDataFetch() tool = ExternalDataVariableEntity(variable="var1", type="type1", config={}) @@ -95,7 +95,7 @@ class TestExternalDataFetch: assert "var1" not in result_inputs assert result_inputs == {"in": "val"} - def test_query_external_data_tool(self, app): + def test_query_external_data_tool(self, app: Flask): fetcher = ExternalDataFetch() tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})