diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index d2b892d9aa..4ce121ba60 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -109,6 +109,8 @@ jobs: - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web + env: + NODE_OPTIONS: --max-old-space-size=4096 run: vp run lint:tss - name: Web type check diff --git a/README.md b/README.md index 778028fc76..e6f8d84931 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock ```bash cd dify cd docker -cp .env.example .env -docker compose up -d +./dify-compose up -d ``` +On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory. + After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. #### Seeking help @@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases. ### Custom configurations -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). ### Metrics Monitoring with Grafana diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index e32ba5f66c..c688a69074 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -75,14 +75,15 @@ console_ns.schema_model( def _convert_values_to_json_serializable_object(value: Segment): - if isinstance(value, FileSegment): - return value.value.model_dump() - elif isinstance(value, ArrayFileSegment): - return [i.model_dump() for i in value.value] - elif isinstance(value, SegmentGroup): - return [_convert_values_to_json_serializable_object(i) for i in value.value] - else: - return value.value + match value: + case FileSegment(): + return value.value.model_dump() + case ArrayFileSegment(): + return [i.model_dump() for i in value.value] + case SegmentGroup(): + return [_convert_values_to_json_serializable_object(i) for i in value.value] + case _: + return value.value def _serialize_var_value(variable: WorkflowDraftVariable): diff --git a/api/libs/typing.py b/api/libs/typing.py deleted file mode 100644 index f84e9911e0..0000000000 --- a/api/libs/typing.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import TypeGuard - - -def is_str_dict(v: object) -> TypeGuard[dict[str, object]]: - return isinstance(v, dict) - - -def is_str(v: object) -> TypeGuard[str]: - return isinstance(v, str) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 01d88d247c..55b6a919d8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -90,7 +90,7 @@ class TestOAuthLogin: mock_redirect, mock_get_providers, resource, - app, + app: Flask, mock_oauth_provider, invite_token, expected_token, @@ -165,7 +165,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" @@ -218,7 +218,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_config.CONSOLE_WEB_URL = "http://localhost:3000" @@ -262,7 +262,7 @@ class TestOAuthCallback: mock_tenant_service, mock_account_service, resource, - app, + app: Flask, oauth_setup, account_status, expected_redirect, @@ -301,7 +301,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): mock_get_providers.return_value = {"github": oauth_setup["provider"]} @@ -337,7 +337,7 @@ class TestOAuthCallback: mock_get_providers, mock_config, resource, - app, + app: Flask, oauth_setup, ): """Defensive test for CLOSED account status handling in OAuth callback. @@ -466,7 +466,7 @@ class TestAccountGeneration: mock_register_service, mock_feature_service, mock_get_account, - app, + app: Flask, user_info, mock_account, allow_register, @@ -505,7 +505,7 @@ class TestAccountGeneration: mock_register_service, mock_feature_service, mock_get_account, - app, + app: Flask, ): user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com") mock_feature_service.get_system_features.return_value.is_allow_register = True @@ -530,7 +530,7 @@ class TestAccountGeneration: mock_feature_service, mock_tenant_service, mock_get_account, - app, + app: Flask, user_info, mock_account, ): diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 8d6b25b5b3..d017e8f2bd 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -47,7 +47,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - app, + app: Flask, mock_account, ): # Arrange @@ -105,7 +105,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email, mock_get_account, mock_is_ip_limit, - app, + app: Flask, mock_account, language_input, expected_language, @@ -154,7 +154,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - app, + app: Flask, ): """ Test successful verification code validation. @@ -201,7 +201,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_get_data, mock_is_rate_limit, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"} @@ -345,7 +345,7 @@ class TestForgotPasswordResetApi: mock_get_account, mock_revoke_token, mock_get_data, - app, + app: Flask, mock_account, ): """ diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 2752e6b34f..7aa4aff1cc 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -30,7 +30,7 @@ class TestPipelineTemplateListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = PipelineTemplateListApi() method = unwrap(api.get) @@ -54,7 +54,7 @@ class TestPipelineTemplateDetailApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -75,7 +75,7 @@ class TestPipelineTemplateDetailApi: assert status == 200 assert response == template - def test_get_returns_404_when_template_not_found(self, app): + def test_get_returns_404_when_template_not_found(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -94,7 +94,7 @@ class TestPipelineTemplateDetailApi: assert status == 404 assert "error" in response - def test_get_returns_404_for_customized_type_not_found(self, app): + def test_get_returns_404_for_customized_type_not_found(self, app: Flask): api = PipelineTemplateDetailApi() method = unwrap(api.get) @@ -119,7 +119,7 @@ class TestCustomizedPipelineTemplateApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.patch) @@ -141,7 +141,7 @@ class TestCustomizedPipelineTemplateApi: update_mock.assert_called_once() assert response == 200 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.delete) @@ -156,7 +156,7 @@ class TestCustomizedPipelineTemplateApi: delete_mock.assert_called_once_with("tpl-1") assert response == 200 - def test_post_success(self, app, db_session_with_containers: Session): + def test_post_success(self, app: Flask, db_session_with_containers: Session): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) @@ -183,7 +183,7 @@ class TestCustomizedPipelineTemplateApi: assert status == 200 assert response == {"data": "yaml-data"} - def test_post_template_not_found(self, app): + def test_post_template_not_found(self, app: Flask): api = CustomizedPipelineTemplateApi() method = unwrap(api.post) @@ -197,7 +197,7 @@ class TestPublishCustomizedPipelineTemplateApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = PublishCustomizedPipelineTemplateApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index f238ca13ee..44eb5c336c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -36,7 +36,7 @@ class TestRagPipelineImportApi: "name": "Test", } - def test_post_success_200(self, app): + def test_post_success_200(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -66,7 +66,7 @@ class TestRagPipelineImportApi: assert status == 200 assert response == {"status": "success"} - def test_post_failed_400(self, app): + def test_post_failed_400(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -96,7 +96,7 @@ class TestRagPipelineImportApi: assert status == 400 assert response == {"status": "failed"} - def test_post_pending_202(self, app): + def test_post_pending_202(self, app: Flask): api = RagPipelineImportApi() method = unwrap(api.post) @@ -132,7 +132,7 @@ class TestRagPipelineImportConfirmApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_confirm_success(self, app): + def test_confirm_success(self, app: Flask): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -160,7 +160,7 @@ class TestRagPipelineImportConfirmApi: assert status == 200 assert response == {"ok": True} - def test_confirm_failed(self, app): + def test_confirm_failed(self, app: Flask): api = RagPipelineImportConfirmApi() method = unwrap(api.post) @@ -194,7 +194,7 @@ class TestRagPipelineImportCheckDependenciesApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = RagPipelineImportCheckDependenciesApi() method = unwrap(api.get) @@ -223,7 +223,7 @@ class TestRagPipelineExportApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_with_include_secret(self, app): + def test_get_with_include_secret(self, app: Flask): api = RagPipelineExportApi() method = unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 1fdb3057b8..c17a83cad3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -391,7 +391,7 @@ class TestPublishedPipelineApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_publish_success(self, app, db_session_with_containers: Session): + def test_publish_success(self, app: Flask, db_session_with_containers: Session): from models.dataset import Pipeline api = PublishedRagPipelineApi() diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 50ad92afa1..b59009f7c4 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -55,7 +55,7 @@ class TestDataSourceApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceApi() method = unwrap(api.get) @@ -79,7 +79,7 @@ class TestDataSourceApi: assert status == 200 assert response["data"][0]["is_bound"] is True - def test_get_no_bindings(self, app, patch_tenant): + def test_get_no_bindings(self, app: Flask, patch_tenant): api = DataSourceApi() method = unwrap(api.get) @@ -95,7 +95,7 @@ class TestDataSourceApi: assert status == 200 assert response["data"] == [] - def test_patch_enable_binding(self, app, patch_tenant, mock_engine): + def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -116,7 +116,7 @@ class TestDataSourceApi: assert status == 200 assert binding.disabled is False - def test_patch_disable_binding(self, app, patch_tenant, mock_engine): + def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -137,7 +137,7 @@ class TestDataSourceApi: assert status == 200 assert binding.disabled is True - def test_patch_binding_not_found(self, app, patch_tenant, mock_engine): + def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -152,7 +152,7 @@ class TestDataSourceApi: with pytest.raises(NotFound): method(api, "b1", "enable") - def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine): + def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -169,7 +169,7 @@ class TestDataSourceApi: with pytest.raises(ValueError): method(api, "b1", "enable") - def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): + def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine): api = DataSourceApi() method = unwrap(api.patch) @@ -192,7 +192,7 @@ class TestDataSourceNotionListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_credential_not_found(self, app, patch_tenant): + def test_get_credential_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -206,7 +206,7 @@ class TestDataSourceNotionListApi: with pytest.raises(NotFound): method(api) - def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine): + def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -247,7 +247,7 @@ class TestDataSourceNotionListApi: assert status == 200 - def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine): + def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -300,7 +300,7 @@ class TestDataSourceNotionListApi: assert status == 200 - def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine): + def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine): api = DataSourceNotionListApi() method = unwrap(api.get) @@ -327,7 +327,7 @@ class TestDataSourceNotionApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_preview_success(self, app, patch_tenant): + def test_get_preview_success(self, app: Flask, patch_tenant): api = DataSourceNotionApi() method = unwrap(api.get) @@ -348,7 +348,7 @@ class TestDataSourceNotionApi: assert status == 200 - def test_post_indexing_estimate_success(self, app, patch_tenant): + def test_post_indexing_estimate_success(self, app: Flask, patch_tenant): api = DataSourceNotionApi() method = unwrap(api.post) @@ -385,7 +385,7 @@ class TestDataSourceNotionDatasetSyncApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceNotionDatasetSyncApi() method = unwrap(api.get) @@ -408,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi: assert status == 200 - def test_get_dataset_not_found(self, app, patch_tenant): + def test_get_dataset_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionDatasetSyncApi() method = unwrap(api.get) @@ -428,7 +428,7 @@ class TestDataSourceNotionDocumentSyncApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, patch_tenant): + def test_get_success(self, app: Flask, patch_tenant): api = DataSourceNotionDocumentSyncApi() method = unwrap(api.get) @@ -451,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi: assert status == 200 - def test_get_document_not_found(self, app, patch_tenant): + def test_get_document_not_found(self, app: Flask, patch_tenant): api = DataSourceNotionDocumentSyncApi() method = unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 0b53ca5585..917aa35fe6 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -57,7 +57,7 @@ class TestConversationListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app, chat_app, user): + def test_get_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -82,7 +82,7 @@ class TestConversationListApi: assert result["has_more"] is False assert len(result["data"]) == 2 - def test_last_conversation_not_exists(self, app, chat_app, user): + def test_last_conversation_not_exists(self, app: Flask, chat_app, user): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -98,7 +98,7 @@ class TestConversationListApi: with pytest.raises(NotFound): method(chat_app) - def test_wrong_app_mode(self, app, non_chat_app): + def test_wrong_app_mode(self, app: Flask, non_chat_app): api = conversation_module.ConversationListApi() method = unwrap(api.get) @@ -112,7 +112,7 @@ class TestConversationApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_delete_success(self, app, chat_app, user): + def test_delete_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -130,7 +130,7 @@ class TestConversationApi: assert status == 204 assert body["result"] == "success" - def test_delete_not_found(self, app, chat_app, user): + def test_delete_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -146,7 +146,7 @@ class TestConversationApi: with pytest.raises(NotFound): method(chat_app, "cid") - def test_delete_wrong_app_mode(self, app, non_chat_app): + def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): api = conversation_module.ConversationApi() method = unwrap(api.delete) @@ -160,7 +160,7 @@ class TestConversationRenameApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_rename_success(self, app, chat_app, user): + def test_rename_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -179,7 +179,7 @@ class TestConversationRenameApi: assert result["id"] == "cid" - def test_rename_not_found(self, app, chat_app, user): + def test_rename_not_found(self, app: Flask, chat_app, user): api = conversation_module.ConversationRenameApi() method = unwrap(api.post) @@ -201,7 +201,7 @@ class TestConversationPinApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_pin_success(self, app, chat_app, user): + def test_pin_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationPinApi() method = unwrap(api.patch) @@ -223,7 +223,7 @@ class TestConversationUnPinApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_unpin_success(self, app, chat_app, user): + def test_unpin_success(self, app: Flask, chat_app, user): api = conversation_module.ConversationUnPinApi() method = unwrap(api.patch) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index 6efdaf2943..e41adccf3c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -49,7 +49,7 @@ class TestTriggerProviderApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_icon_success(self, app): + def test_icon_success(self, app: Flask): api = TriggerProviderIconApi() method = unwrap(api.get) @@ -63,7 +63,7 @@ class TestTriggerProviderApis: ): assert method(api, "github") == "icon" - def test_list_providers(self, app): + def test_list_providers(self, app: Flask): api = TriggerProviderListApi() method = unwrap(api.get) @@ -77,7 +77,7 @@ class TestTriggerProviderApis: ): assert method(api) == [] - def test_provider_info(self, app): + def test_provider_info(self, app: Flask): api = TriggerProviderInfoApi() method = unwrap(api.get) @@ -97,7 +97,7 @@ class TestTriggerSubscriptionListApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = TriggerSubscriptionListApi() method = unwrap(api.get) @@ -111,7 +111,7 @@ class TestTriggerSubscriptionListApi: ): assert method(api, "github") == [] - def test_list_invalid_provider(self, app): + def test_list_invalid_provider(self, app: Flask): api = TriggerSubscriptionListApi() method = unwrap(api.get) @@ -132,7 +132,7 @@ class TestTriggerSubscriptionBuilderApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create_builder(self, app): + def test_create_builder(self, app: Flask): api = TriggerSubscriptionBuilderCreateApi() method = unwrap(api.post) @@ -147,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis: result = method(api, "github") assert "subscription_builder" in result - def test_get_builder(self, app): + def test_get_builder(self, app: Flask): api = TriggerSubscriptionBuilderGetApi() method = unwrap(api.get) @@ -160,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_verify_builder(self, app): + def test_verify_builder(self, app: Flask): api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) @@ -174,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"ok": True} - def test_verify_builder_error(self, app): + def test_verify_builder_error(self, app: Flask): api = TriggerSubscriptionBuilderVerifyApi() method = unwrap(api.post) @@ -189,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis: with pytest.raises(ValueError): method(api, "github", "b1") - def test_update_builder(self, app): + def test_update_builder(self, app: Flask): api = TriggerSubscriptionBuilderUpdateApi() method = unwrap(api.post) @@ -203,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert method(api, "github", "b1") == {"id": "b1"} - def test_logs(self, app): + def test_logs(self, app: Flask): api = TriggerSubscriptionBuilderLogsApi() method = unwrap(api.get) @@ -220,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis: ): assert "logs" in method(api, "github", "b1") - def test_build(self, app): + def test_build(self, app: Flask): api = TriggerSubscriptionBuilderBuildApi() method = unwrap(api.post) @@ -240,7 +240,7 @@ class TestTriggerSubscriptionCrud: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_update_rename_only(self, app): + def test_update_rename_only(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -259,7 +259,7 @@ class TestTriggerSubscriptionCrud: ): assert method(api, "s1") == 200 - def test_update_not_found(self, app): + def test_update_not_found(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -274,7 +274,7 @@ class TestTriggerSubscriptionCrud: with pytest.raises(NotFoundError): method(api, "x") - def test_update_rebuild(self, app): + def test_update_rebuild(self, app: Flask): api = TriggerSubscriptionUpdateApi() method = unwrap(api.post) @@ -297,7 +297,7 @@ class TestTriggerSubscriptionCrud: ): assert method(api, "s1") == 200 - def test_delete_subscription(self, app): + def test_delete_subscription(self, app: Flask): api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -320,7 +320,7 @@ class TestTriggerSubscriptionCrud: assert result["result"] == "success" - def test_delete_subscription_value_error(self, app): + def test_delete_subscription_value_error(self, app: Flask): api = TriggerSubscriptionDeleteApi() method = unwrap(api.post) @@ -346,7 +346,7 @@ class TestTriggerOAuthApis: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_authorize_success(self, app): + def test_oauth_authorize_success(self, app: Flask): api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) @@ -373,7 +373,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 200 - def test_oauth_authorize_no_client(self, app): + def test_oauth_authorize_no_client(self, app: Flask): api = TriggerOAuthAuthorizeApi() method = unwrap(api.get) @@ -388,7 +388,7 @@ class TestTriggerOAuthApis: with pytest.raises(NotFoundError): method(api, "github") - def test_oauth_callback_forbidden(self, app): + def test_oauth_callback_forbidden(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -396,7 +396,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_success(self, app): + def test_oauth_callback_success(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -426,7 +426,7 @@ class TestTriggerOAuthApis: resp = method(api, "github") assert resp.status_code == 302 - def test_oauth_callback_no_oauth_client(self, app): + def test_oauth_callback_no_oauth_client(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -450,7 +450,7 @@ class TestTriggerOAuthApis: with pytest.raises(Forbidden): method(api, "github") - def test_oauth_callback_empty_credentials(self, app): + def test_oauth_callback_empty_credentials(self, app: Flask): api = TriggerOAuthCallbackApi() method = unwrap(api.get) @@ -484,7 +484,7 @@ class TestTriggerOAuthClientManageApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_client(self, app): + def test_get_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.get) @@ -511,7 +511,7 @@ class TestTriggerOAuthClientManageApi: result = method(api, "github") assert "configured" in result - def test_post_client(self, app): + def test_post_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.post) @@ -525,7 +525,7 @@ class TestTriggerOAuthClientManageApi: ): assert method(api, "github") == {"ok": True} - def test_delete_client(self, app): + def test_delete_client(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.delete) @@ -539,7 +539,7 @@ class TestTriggerOAuthClientManageApi: ): assert method(api, "github") == {"ok": True} - def test_oauth_client_post_value_error(self, app): + def test_oauth_client_post_value_error(self, app: Flask): api = TriggerOAuthClientManageApi() method = unwrap(api.post) @@ -560,7 +560,7 @@ class TestTriggerSubscriptionVerifyApi: def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_verify_success(self, app): + def test_verify_success(self, app: Flask): api = TriggerSubscriptionVerifyApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 5791d2f6e2..b73d28e4c4 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -291,7 +291,7 @@ class TestDatasetListApiGet: mock_current_user, mock_provider_mgr, mock_marshal, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -326,7 +326,7 @@ class TestDatasetListApiPost: mock_dataset_svc, mock_current_user, mock_marshal, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -352,7 +352,7 @@ class TestDatasetListApiPost: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_tenant, ): from controllers.service_api.dataset.dataset import DatasetListApi @@ -390,7 +390,7 @@ class TestDatasetApiGet: mock_provider_mgr, mock_marshal, mock_perm_svc, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -440,7 +440,7 @@ class TestDatasetApiGet: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -468,7 +468,7 @@ class TestDatasetApiDelete: mock_dataset_svc, mock_current_user, mock_perm_svc, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -490,7 +490,7 @@ class TestDatasetApiDelete: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -511,7 +511,7 @@ class TestDatasetApiDelete: self, mock_dataset_svc, mock_current_user, - app, + app: Flask, mock_dataset, ): from controllers.service_api.dataset.dataset import DatasetApi @@ -543,7 +543,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -574,7 +574,7 @@ class TestDocumentStatusApiPatch: def test_batch_update_status_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -603,7 +603,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -636,7 +636,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -669,7 +669,7 @@ class TestDocumentStatusApiPatch: mock_dataset_svc, mock_current_user, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -709,7 +709,7 @@ class TestDatasetTagsApiGet: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -731,7 +731,7 @@ class TestDatasetTagsApiGet: def test_list_tags_from_db( self, mock_current_user, - app, + app: Flask, db_session_with_containers: Session, ): """Integration test: creates real Tag rows and retrieves them @@ -774,7 +774,7 @@ class TestDatasetTagsApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -797,7 +797,7 @@ class TestDatasetTagsApiPost: mock_tag_svc.save_tags.assert_called_once() @patch("controllers.service_api.dataset.dataset.current_user") - def test_create_tag_forbidden(self, mock_current_user, app): + def test_create_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -826,7 +826,7 @@ class TestDatasetTagsApiPatch: mock_current_user, mock_service_api_ns, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -852,7 +852,7 @@ class TestDatasetTagsApiPatch: mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1") @patch("controllers.service_api.dataset.dataset.current_user") - def test_update_tag_forbidden(self, mock_current_user, app): + def test_update_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi mock_current_user.__class__ = Account @@ -880,7 +880,7 @@ class TestDatasetTagsApiDelete: mock_current_user, mock_service_api_ns, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsApi @@ -905,7 +905,7 @@ class TestDatasetTagsApiDelete: mock_tag_svc.delete_tag.assert_called_once_with("tag-1") @patch("libs.login.current_user") - def test_delete_tag_forbidden(self, mock_current_user, app): + def test_delete_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagsApi user_obj = Mock(spec=Account) @@ -933,7 +933,7 @@ class TestDatasetTagsBindingStatusApi: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi @@ -963,7 +963,7 @@ class TestDatasetTagBindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagBindingApi @@ -988,7 +988,7 @@ class TestDatasetTagBindingApiPost: ) @patch("controllers.service_api.dataset.dataset.current_user") - def test_bind_tags_forbidden(self, mock_current_user, app): + def test_bind_tags_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagBindingApi mock_current_user.__class__ = Account @@ -1014,7 +1014,7 @@ class TestDatasetTagUnbindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi @@ -1044,7 +1044,7 @@ class TestDatasetTagUnbindingApiPost: self, mock_current_user, mock_tag_svc, - app, + app: Flask, ): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi @@ -1069,7 +1069,7 @@ class TestDatasetTagUnbindingApiPost: ) @patch("controllers.service_api.dataset.dataset.current_user") - def test_unbind_tag_forbidden(self, mock_current_user, app): + def test_unbind_tag_forbidden(self, mock_current_user, app: Flask): from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi mock_current_user.__class__ = Account diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py index de9e691434..0a4e495f36 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py @@ -240,7 +240,7 @@ class TestDecodeJwtToken: mock_access_mode: MagicMock, mock_validate_token: MagicMock, mock_validate_user: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers) @@ -300,7 +300,7 @@ class TestDecodeJwtToken: mock_extract: MagicMock, mock_passport_cls: MagicMock, mock_features: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False) @@ -325,7 +325,7 @@ class TestDecodeJwtToken: mock_extract: MagicMock, mock_passport_cls: MagicMock, mock_features: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, _ = self._create_app_site_enduser(db_session_with_containers) @@ -351,7 +351,7 @@ class TestDecodeJwtToken: mock_extract: MagicMock, mock_passport_cls: MagicMock, mock_features: MagicMock, - app, + app: Flask, db_session_with_containers: Session, ) -> None: app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index 54ee133bfe..d1af0a56ef 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -21,6 +21,8 @@ from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus +TenantAndAccount = tuple[Tenant, Account] + @dataclass class TestTask: @@ -74,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration: return tenant, account @pytest.fixture - def test_queue(self, test_tenant_and_account): + def test_queue(self, test_tenant_and_account: TenantAndAccount): """Create a generic test queue for testing.""" tenant, _ = test_tenant_and_account return TenantIsolatedTaskQueue(tenant.id, "test_queue") @pytest.fixture - def secondary_queue(self, test_tenant_and_account): + def secondary_queue(self, test_tenant_and_account: TenantAndAccount): """Create a secondary test queue for testing isolation.""" tenant, _ = test_tenant_and_account return TenantIsolatedTaskQueue(tenant.id, "secondary_queue") - def test_queue_initialization(self, test_tenant_and_account): + def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount): """Test queue initialization with correct key generation.""" tenant, _ = test_tenant_and_account queue = TenantIsolatedTaskQueue(tenant.id, "test-key") @@ -95,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" assert queue._task_key == f"tenant_test-key_task:{tenant.id}" - def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker): + def test_tenant_isolation( + self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker + ): """Test that different tenants have isolated queues.""" tenant1, _ = test_tenant_and_account @@ -115,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}" assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}" - def test_key_isolation(self, test_tenant_and_account): + def test_key_isolation(self, test_tenant_and_account: TenantAndAccount): """Test that different keys have isolated queues.""" tenant, _ = test_tenant_and_account queue1 = TenantIsolatedTaskQueue(tenant.id, "key1") @@ -293,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert isinstance(task, dict) assert task["index"] == i # FIFO order - def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker): + def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """Test concurrent operations on different queues.""" tenant, _ = test_tenant_and_account @@ -436,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return tenant, account - def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker): + def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """ Test compatibility with legacy queues containing only string data. @@ -466,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility: expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] assert pulled_tasks == expected_order - def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker): + def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """ Test complete migration scenario from legacy to new system. @@ -547,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility: assert task["tenant_id"] == tenant.id assert task["processing_type"] == "new_system" - def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker): + def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker): """ Test error recovery when legacy queue contains malformed data. diff --git a/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py new file mode 100644 index 0000000000..49d06986fd --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_permission_service.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from uuid import uuid4 + +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from models.account import TenantPluginPermission +from services.plugin.plugin_permission_service import PluginPermissionService + + +def _tenant_id() -> str: + return str(uuid4()) + + +def _get_permission(session: Session, tenant_id: str) -> TenantPluginPermission | None: + session.expire_all() + stmt = select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id) + return session.scalars(stmt).one_or_none() + + +def _count_permissions(session: Session, tenant_id: str) -> int: + stmt = select(func.count()).select_from(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id) + return session.scalar(stmt) or 0 + + +class TestGetPermission: + """Integration tests for PluginPermissionService.get_permission using testcontainers.""" + + def test_returns_permission_when_found(self, db_session_with_containers: Session): + tenant_id = _tenant_id() + permission = TenantPluginPermission( + tenant_id=tenant_id, + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + db_session_with_containers.add(permission) + db_session_with_containers.commit() + + result = PluginPermissionService.get_permission(tenant_id) + + assert result is not None + assert result.id == permission.id + assert result.tenant_id == tenant_id + assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS + assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE + + def test_returns_none_when_not_found(self, db_session_with_containers: Session): + result = PluginPermissionService.get_permission(_tenant_id()) + + assert result is None + + +class TestChangePermission: + """Integration tests for PluginPermissionService.change_permission using testcontainers.""" + + def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session): + tenant_id = _tenant_id() + + result = PluginPermissionService.change_permission( + tenant_id, + TenantPluginPermission.InstallPermission.EVERYONE, + TenantPluginPermission.DebugPermission.EVERYONE, + ) + + permission = _get_permission(db_session_with_containers, tenant_id) + assert result is True + assert permission is not None + assert permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE + assert permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE + + def test_updates_existing_permission(self, db_session_with_containers: Session): + tenant_id = _tenant_id() + existing = TenantPluginPermission( + tenant_id=tenant_id, + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + db_session_with_containers.add(existing) + db_session_with_containers.commit() + + result = PluginPermissionService.change_permission( + tenant_id, + TenantPluginPermission.InstallPermission.ADMINS, + TenantPluginPermission.DebugPermission.ADMINS, + ) + + permission = _get_permission(db_session_with_containers, tenant_id) + assert result is True + assert permission is not None + assert permission.id == existing.id + assert permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS + assert permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS + assert _count_permissions(db_session_with_containers, tenant_id) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 3229693fd4..e2fe6c8476 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -7,6 +7,7 @@ from faker import Faker from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom +from models import App from models.model import EndUser from models.workflow import Workflow from services.app_generate_service import AppGenerateService @@ -184,7 +185,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers: Session, app): + def _create_test_workflow(self, db_session_with_containers: Session, app: App): """ Helper method to create a test workflow for testing. diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index cd63d3ad6c..1a1efe0337 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration: return app - def _create_conversation(self, db_session_with_containers: Session, app): + def _create_conversation(self, db_session_with_containers: Session, app: App): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 70aa813142..7b9e9924cd 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -4,6 +4,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session +from models import App, CreatorUserRole from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage @@ -88,7 +89,7 @@ class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers: Session, app): + def _create_test_end_user(self, db_session_with_containers: Session, app: App): """ Helper method to create a test end user for testing. @@ -116,7 +117,7 @@ class TestSavedMessageService: return end_user - def _create_test_message(self, db_session_with_containers: Session, app, user): + def _create_test_message(self, db_session_with_containers: Session, app: App, user): """ Helper method to create a test message for testing. @@ -199,13 +200,13 @@ class TestSavedMessageService: saved_message1 = SavedMessage( app_id=app.id, message_id=message1.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) saved_message2 = SavedMessage( app_id=app.id, message_id=message2.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -272,13 +273,13 @@ class TestSavedMessageService: saved_message1 = SavedMessage( app_id=app.id, message_id=message1.id, - created_by_role="end_user", + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) saved_message2 = SavedMessage( app_id=app.id, message_id=message2.id, - created_by_role="end_user", + created_by_role=CreatorUserRole.END_USER, created_by=end_user.id, ) @@ -449,7 +450,7 @@ class TestSavedMessageService: saved_message = SavedMessage( app_id=app.id, message_id=message.id, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) @@ -540,7 +541,9 @@ class TestSavedMessageService: message = self._create_test_message(db_session_with_containers, app, account) # Pre-create a saved message - saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id) + saved = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id + ) db_session_with_containers.add(saved) db_session_with_containers.commit() @@ -571,7 +574,9 @@ class TestSavedMessageService: end_user = self._create_test_end_user(db_session_with_containers, app) message = self._create_test_message(db_session_with_containers, app, end_user) - saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id) + saved = SavedMessage( + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id + ) db_session_with_containers.add(saved) db_session_with_containers.commit() @@ -596,10 +601,10 @@ class TestSavedMessageService: # Both users save the same message saved_account = SavedMessage( - app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account1.id ) saved_end_user = SavedMessage( - app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id + app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id ) db_session_with_containers.add_all([saved_account, saved_end_user]) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index f2307fbd7d..797731d04b 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account +from models import Account, App from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation @@ -93,7 +93,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers: Session, app): + def _create_test_end_user(self, db_session_with_containers: Session, app: App): """ Helper method to create a test end user for testing. diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index 78413a0798..0fb0ebc330 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -185,7 +185,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, ): @@ -263,7 +263,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, ): @@ -312,7 +312,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, language, @@ -358,7 +358,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, ): """ @@ -398,7 +398,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, ): """ @@ -438,7 +438,7 @@ class TestActivateApi: mock_db, mock_revoke_token, mock_get_invitation, - app, + app: Flask, mock_invitation, mock_account, ): diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index 7b2c7569fe..102af9b250 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -195,7 +195,7 @@ class TestEmailCodeLoginSendEmailApi: mock_get_user, mock_is_ip_limit, mock_db, - app, + app: Flask, mock_account, language_input, expected_language, @@ -267,7 +267,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -315,7 +315,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -431,7 +431,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -474,7 +474,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ @@ -515,7 +515,7 @@ class TestEmailCodeLoginApi: mock_revoke_token, mock_get_data, mock_db, - app, + app: Flask, mock_account, ): """ diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 5284f29eed..ace2ce5706 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -412,7 +412,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -448,7 +448,7 @@ class TestLoginApi: mock_revoke_token, mock_get_token_data, mock_db, - app, + app: Flask, ): mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"} mock_get_account.side_effect = Unauthorized("Account is banned.") diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index 15c95f6b94..22974ca416 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -74,7 +74,7 @@ class TestRefreshTokenApi: assert response.json["result"] == "success" @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) - def test_refresh_fails_without_token(self, mock_extract_token, app): + def test_refresh_fails_without_token(self, mock_extract_token, app: Flask): """ Test token refresh failure when no refresh token provided. @@ -98,7 +98,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh failure with invalid refresh token. @@ -123,7 +123,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh failure with expired refresh token. @@ -148,7 +148,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app): + def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app: Flask): """ Test token refresh with empty string token. diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 5136922e88..9c5b5ec256 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -29,7 +30,7 @@ def unwrap(func): class TestDatasourcePluginOAuthAuthorizationUrl: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -61,7 +62,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: assert response.status_code == 200 - def test_get_no_oauth_config(self, app): + def test_get_no_oauth_config(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -80,7 +81,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: with pytest.raises(ValueError): method(api, "notion") - def test_get_without_credential_id_sets_cookie(self, app): + def test_get_without_credential_id_sets_cookie(self, app: Flask): api = DatasourcePluginOAuthAuthorizationUrl() method = unwrap(api.get) @@ -115,7 +116,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl: class TestDatasourceOAuthCallback: - def test_callback_success_new_credential(self, app): + def test_callback_success_new_credential(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -157,7 +158,7 @@ class TestDatasourceOAuthCallback: assert response.status_code == 302 - def test_callback_missing_context(self, app): + def test_callback_missing_context(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -165,7 +166,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(Forbidden): method(api, "notion") - def test_callback_invalid_context(self, app): + def test_callback_invalid_context(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -180,7 +181,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(Forbidden): method(api, "notion") - def test_callback_oauth_config_not_found(self, app): + def test_callback_oauth_config_not_found(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -202,7 +203,7 @@ class TestDatasourceOAuthCallback: with pytest.raises(NotFound): method(api, "notion") - def test_callback_reauthorize_existing_credential(self, app): + def test_callback_reauthorize_existing_credential(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -245,7 +246,7 @@ class TestDatasourceOAuthCallback: assert response.status_code == 302 assert "/oauth-callback" in response.location - def test_callback_context_id_from_cookie(self, app): + def test_callback_context_id_from_cookie(self, app: Flask): api = DatasourceOAuthCallback() method = unwrap(api.get) @@ -289,7 +290,7 @@ class TestDatasourceOAuthCallback: class TestDatasourceAuth: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -312,7 +313,7 @@ class TestDatasourceAuth: assert status == 200 - def test_post_invalid_credentials(self, app): + def test_post_invalid_credentials(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -334,7 +335,7 @@ class TestDatasourceAuth: with pytest.raises(ValueError): method(api, "notion") - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasourceAuth() method = unwrap(api.get) @@ -355,7 +356,7 @@ class TestDatasourceAuth: assert status == 200 assert response["result"] - def test_post_missing_credentials(self, app): + def test_post_missing_credentials(self, app: Flask): api = DatasourceAuth() method = unwrap(api.post) @@ -372,7 +373,7 @@ class TestDatasourceAuth: with pytest.raises(ValueError): method(api, "notion") - def test_get_empty_list(self, app): + def test_get_empty_list(self, app: Flask): api = DatasourceAuth() method = unwrap(api.get) @@ -395,7 +396,7 @@ class TestDatasourceAuth: class TestDatasourceAuthDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasourceAuthDeleteApi() method = unwrap(api.post) @@ -418,7 +419,7 @@ class TestDatasourceAuthDeleteApi: assert status == 200 - def test_delete_missing_credential_id(self, app): + def test_delete_missing_credential_id(self, app: Flask): api = DatasourceAuthDeleteApi() method = unwrap(api.post) @@ -437,7 +438,7 @@ class TestDatasourceAuthDeleteApi: class TestDatasourceAuthUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestDatasourceAuthUpdateApi: assert status == 201 - def test_update_with_credentials_none(self, app): + def test_update_with_credentials_none(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -484,7 +485,7 @@ class TestDatasourceAuthUpdateApi: update_mock.assert_called_once() assert status == 201 - def test_update_name_only(self, app): + def test_update_name_only(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -507,7 +508,7 @@ class TestDatasourceAuthUpdateApi: assert status == 201 - def test_update_with_empty_credentials_dict(self, app): + def test_update_with_empty_credentials_dict(self, app: Flask): api = DatasourceAuthUpdateApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestDatasourceAuthUpdateApi: class TestDatasourceAuthListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = DatasourceAuthListApi() method = unwrap(api.get) @@ -553,7 +554,7 @@ class TestDatasourceAuthListApi: assert status == 200 - def test_auth_list_empty(self, app): + def test_auth_list_empty(self, app: Flask): api = DatasourceAuthListApi() method = unwrap(api.get) @@ -574,7 +575,7 @@ class TestDatasourceAuthListApi: assert status == 200 assert response["result"] == [] - def test_hardcode_list_empty(self, app): + def test_hardcode_list_empty(self, app: Flask): api = DatasourceHardCodeAuthListApi() method = unwrap(api.get) @@ -597,7 +598,7 @@ class TestDatasourceAuthListApi: class TestDatasourceHardCodeAuthListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = DatasourceHardCodeAuthListApi() method = unwrap(api.get) @@ -619,7 +620,7 @@ class TestDatasourceHardCodeAuthListApi: class TestDatasourceAuthOauthCustomClient: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -642,7 +643,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.delete) @@ -662,7 +663,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_post_empty_payload(self, app): + def test_post_empty_payload(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -685,7 +686,7 @@ class TestDatasourceAuthOauthCustomClient: assert status == 200 - def test_post_disabled_flag(self, app): + def test_post_disabled_flag(self, app: Flask): api = DatasourceAuthOauthCustomClient() method = unwrap(api.post) @@ -714,7 +715,7 @@ class TestDatasourceAuthOauthCustomClient: class TestDatasourceAuthDefaultApi: - def test_set_default_success(self, app): + def test_set_default_success(self, app: Flask): api = DatasourceAuthDefaultApi() method = unwrap(api.post) @@ -737,7 +738,7 @@ class TestDatasourceAuthDefaultApi: assert status == 200 - def test_default_missing_id(self, app): + def test_default_missing_id(self, app: Flask): api = DatasourceAuthDefaultApi() method = unwrap(api.post) @@ -756,7 +757,7 @@ class TestDatasourceAuthDefaultApi: class TestDatasourceUpdateProviderNameApi: - def test_update_name_success(self, app): + def test_update_name_success(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) @@ -779,7 +780,7 @@ class TestDatasourceUpdateProviderNameApi: assert status == 200 - def test_update_name_too_long(self, app): + def test_update_name_too_long(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) @@ -799,7 +800,7 @@ class TestDatasourceUpdateProviderNameApi: with pytest.raises(ValueError): method(api, "notion") - def test_update_name_missing_credential_id(self, app): + def test_update_name_missing_credential_id(self, app: Flask): api = DatasourceUpdateProviderNameApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py index 7a8ccde55a..d4c6a775ec 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console import console_ns @@ -25,7 +26,7 @@ class TestDataSourceContentPreviewApi: "credential_id": "cred-1", } - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -66,7 +67,7 @@ class TestDataSourceContentPreviewApi: assert status == 200 assert response == preview_result - def test_post_forbidden_non_account_user(self, app): + def test_post_forbidden_non_account_user(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -85,7 +86,7 @@ class TestDataSourceContentPreviewApi: with pytest.raises(Forbidden): method(api, pipeline, "node-1") - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) @@ -108,7 +109,7 @@ class TestDataSourceContentPreviewApi: with pytest.raises(ValueError): method(api, pipeline, "node-1") - def test_post_without_credential_id(self, app): + def test_post_without_credential_id(self, app: Flask): api = DataSourceContentPreviewApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index 9465936f28..e28d68ee5a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -2,6 +2,7 @@ import datetime from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden, NotFound import services @@ -58,7 +59,7 @@ class TestDatasetList: user.is_dataset_editor = True return user - def test_get_success_basic(self, app): + def test_get_success_basic(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -93,7 +94,7 @@ class TestDatasetList: assert resp["total"] == 1 assert resp["data"][0]["embedding_available"] is True - def test_get_with_ids_filter(self, app): + def test_get_with_ids_filter(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -128,7 +129,7 @@ class TestDatasetList: assert status == 200 assert resp["total"] == 2 - def test_get_with_tag_ids(self, app): + def test_get_with_tag_ids(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -161,7 +162,7 @@ class TestDatasetList: assert status == 200 - def test_embedding_available_false(self, app): + def test_embedding_available_false(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -203,7 +204,7 @@ class TestDatasetList: assert resp["data"][0]["embedding_available"] is False - def test_partial_members_permission(self, app): + def test_partial_members_permission(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) @@ -242,7 +243,7 @@ class TestDatasetList: class TestDatasetListApiPost: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -290,7 +291,7 @@ class TestDatasetListApiPost: assert status == 201 - def test_post_forbidden(self, app): + def test_post_forbidden(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -310,7 +311,7 @@ class TestDatasetListApiPost: with pytest.raises(Forbidden): method(api) - def test_post_duplicate_name(self, app): + def test_post_duplicate_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -335,7 +336,7 @@ class TestDatasetListApiPost: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload_missing_name(self, app): + def test_post_invalid_payload_missing_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -343,7 +344,7 @@ class TestDatasetListApiPost: with pytest.raises(ValueError): method(api) - def test_post_invalid_indexing_technique(self, app): + def test_post_invalid_indexing_technique(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -356,7 +357,7 @@ class TestDatasetListApiPost: with pytest.raises(ValueError, match="Invalid indexing technique"): method(api) - def test_post_invalid_provider(self, app): + def test_post_invalid_provider(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) @@ -371,7 +372,7 @@ class TestDatasetListApiPost: class TestDatasetApiGet: - def test_get_success_basic(self, app): + def test_get_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -427,7 +428,7 @@ class TestDatasetApiGet: assert status == 200 assert data["embedding_available"] is True - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -448,7 +449,7 @@ class TestDatasetApiGet: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -475,7 +476,7 @@ class TestDatasetApiGet: with pytest.raises(Forbidden, match="no access"): method(api, dataset_id) - def test_get_high_quality_embedding_unavailable(self, app): + def test_get_high_quality_embedding_unavailable(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -530,7 +531,7 @@ class TestDatasetApiGet: assert data["embedding_available"] is False - def test_get_partial_members_permission(self, app): + def test_get_partial_members_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.get) @@ -590,7 +591,7 @@ class TestDatasetApiGet: class TestDatasetApiPatch: - def test_patch_success_basic(self, app): + def test_patch_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -659,7 +660,7 @@ class TestDatasetApiPatch: assert status == 200 assert result["partial_member_list"] == [] - def test_patch_dataset_not_found(self, app): + def test_patch_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -674,7 +675,7 @@ class TestDatasetApiPatch: with pytest.raises(NotFound, match="Dataset not found"): method(api, "missing") - def test_patch_permission_denied(self, app): + def test_patch_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -704,7 +705,7 @@ class TestDatasetApiPatch: with pytest.raises(Forbidden): method(api, dataset_id) - def test_patch_partial_members_update(self, app): + def test_patch_partial_members_update(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -773,7 +774,7 @@ class TestDatasetApiPatch: assert result["partial_member_list"] == payload["partial_member_list"] - def test_patch_clear_partial_members(self, app): + def test_patch_clear_partial_members(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) @@ -843,7 +844,7 @@ class TestDatasetApiPatch: class TestDatasetApiDelete: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -874,7 +875,7 @@ class TestDatasetApiDelete: assert status == 204 assert result == {"result": "success"} - def test_delete_forbidden_no_permission(self, app): + def test_delete_forbidden_no_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -893,7 +894,7 @@ class TestDatasetApiDelete: with pytest.raises(Forbidden): method(api, dataset_id) - def test_delete_dataset_not_found(self, app): + def test_delete_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -917,7 +918,7 @@ class TestDatasetApiDelete: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_delete_dataset_in_use(self, app): + def test_delete_dataset_in_use(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) @@ -943,7 +944,7 @@ class TestDatasetApiDelete: class TestDatasetUseCheckApi: - def test_get_use_check_true(self, app): + def test_get_use_check_true(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) @@ -962,7 +963,7 @@ class TestDatasetUseCheckApi: assert status == 200 assert result == {"is_using": True} - def test_get_use_check_false(self, app): + def test_get_use_check_false(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) @@ -983,7 +984,7 @@ class TestDatasetUseCheckApi: class TestDatasetQueryApi: - def test_get_queries_success(self, app): + def test_get_queries_success(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1027,7 +1028,7 @@ class TestDatasetQueryApi: assert response["has_more"] is False assert len(response["data"]) == 2 - def test_get_queries_dataset_not_found(self, app): + def test_get_queries_dataset_not_found(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1049,7 +1050,7 @@ class TestDatasetQueryApi: with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) - def test_get_queries_permission_denied(self, app): + def test_get_queries_permission_denied(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1078,7 +1079,7 @@ class TestDatasetQueryApi: with pytest.raises(Forbidden): method(api, dataset_id) - def test_get_queries_pagination_has_more(self, app): + def test_get_queries_pagination_has_more(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) @@ -1152,7 +1153,7 @@ class TestDatasetIndexingEstimateApi: "dataset_id": None, } - def test_post_success_upload_file(self, app): + def test_post_success_upload_file(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) @@ -1193,7 +1194,7 @@ class TestDatasetIndexingEstimateApi: assert status == 200 assert response == {"tokens": 100} - def test_post_file_not_found(self, app): + def test_post_file_not_found(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) @@ -1223,7 +1224,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(NotFound): method(api) - def test_post_llm_bad_request_error(self, app): + def test_post_llm_bad_request_error(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1258,7 +1259,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(ProviderNotInitializeError): method(api) - def test_post_provider_token_not_init(self, app): + def test_post_provider_token_not_init(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1293,7 +1294,7 @@ class TestDatasetIndexingEstimateApi: with pytest.raises(ProviderNotInitializeError): method(api) - def test_post_generic_exception(self, app): + def test_post_generic_exception(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() @@ -1330,7 +1331,7 @@ class TestDatasetIndexingEstimateApi: class TestDatasetRelatedAppListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1368,7 +1369,7 @@ class TestDatasetRelatedAppListApi: assert response["total"] == 2 assert response["data"] == [app1, app2] - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1386,7 +1387,7 @@ class TestDatasetRelatedAppListApi: with pytest.raises(NotFound): method(api, "dataset-1") - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1410,7 +1411,7 @@ class TestDatasetRelatedAppListApi: with pytest.raises(Forbidden): method(api, "dataset-1") - def test_get_filters_none_apps(self, app): + def test_get_filters_none_apps(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) @@ -1449,7 +1450,7 @@ class TestDatasetRelatedAppListApi: class TestDatasetIndexingStatusApi: - def test_get_success_with_documents(self, app): + def test_get_success_with_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1490,7 +1491,7 @@ class TestDatasetIndexingStatusApi: assert item["completed_segments"] == 3 assert item["total_segments"] == 3 - def test_get_success_no_documents(self, app): + def test_get_success_no_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1510,7 +1511,7 @@ class TestDatasetIndexingStatusApi: assert status == 200 assert response == {"data": []} - def test_segment_counts_different_values(self, app): + def test_segment_counts_different_values(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) @@ -1550,7 +1551,7 @@ class TestDatasetIndexingStatusApi: class TestDatasetApiKeyApi: - def test_get_api_keys_success(self, app): + def test_get_api_keys_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.get) @@ -1587,7 +1588,7 @@ class TestDatasetApiKeyApi: assert response["data"][1]["id"] == "key-2" assert response["data"][1]["token"] == "ds-def" - def test_post_create_api_key_success(self, app): + def test_post_create_api_key_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) @@ -1632,7 +1633,7 @@ class TestDatasetApiKeyApi: assert response["type"] == "dataset" assert response["created_at"] is not None - def test_post_exceed_max_keys(self, app): + def test_post_exceed_max_keys(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) @@ -1658,7 +1659,7 @@ class TestDatasetApiKeyApi: class TestDatasetApiDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) @@ -1688,7 +1689,7 @@ class TestDatasetApiDeleteApi: assert status == 204 assert response["result"] == "success" - def test_delete_key_not_found(self, app): + def test_delete_key_not_found(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) @@ -1708,7 +1709,7 @@ class TestDatasetApiDeleteApi: class TestDatasetEnableApiApi: - def test_enable_api(self, app): + def test_enable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) @@ -1724,7 +1725,7 @@ class TestDatasetEnableApiApi: assert status == 200 assert response["result"] == "success" - def test_disable_api(self, app): + def test_disable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) @@ -1742,7 +1743,7 @@ class TestDatasetEnableApiApi: class TestDatasetApiBaseUrlApi: - def test_get_api_base_url_from_config(self, app): + def test_get_api_base_url_from_config(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1757,7 +1758,7 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "https://example.com/v1" - def test_get_api_base_url_from_request(self, app): + def test_get_api_base_url_from_request(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1772,7 +1773,7 @@ class TestDatasetApiBaseUrlApi: assert response["api_base_url"] == "http://localhost:5000/v1" - def test_get_api_base_url_no_double_v1(self, app): + def test_get_api_base_url_no_double_v1(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) @@ -1789,7 +1790,7 @@ class TestDatasetApiBaseUrlApi: class TestDatasetRetrievalSettingApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRetrievalSettingApi() method = unwrap(api.get) @@ -1810,7 +1811,7 @@ class TestDatasetRetrievalSettingApi: class TestDatasetRetrievalSettingMockApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetRetrievalSettingMockApi() method = unwrap(api.get) @@ -1827,7 +1828,7 @@ class TestDatasetRetrievalSettingMockApi: class TestDatasetErrorDocs: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) @@ -1850,7 +1851,7 @@ class TestDatasetErrorDocs: assert status == 200 assert response["total"] == 1 - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) @@ -1866,7 +1867,7 @@ class TestDatasetErrorDocs: class TestDatasetPermissionUserListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) @@ -1897,7 +1898,7 @@ class TestDatasetPermissionUserListApi: assert status == 200 assert response["data"] == users - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) @@ -1923,7 +1924,7 @@ class TestDatasetPermissionUserListApi: class TestDatasetAutoDisableLogApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) @@ -1946,7 +1947,7 @@ class TestDatasetAutoDisableLogApi: assert status == 200 assert response == logs - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index d9b02ac453..ff9e1736d2 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound import services @@ -239,7 +240,7 @@ class TestDatasetDocumentListApi: assert "documents" in response - def test_post_forbidden(self, app): + def test_post_forbidden(self, app: Flask): api = DatasetDocumentListApi() method = unwrap(api.post) @@ -395,7 +396,7 @@ class TestDocumentDownloadApi: class TestDocumentProcessingApi: - def test_processing_forbidden_when_not_editor(self, app): + def test_processing_forbidden_when_not_editor(self, app: Flask): api = DocumentProcessingApi() method = unwrap(api.patch) @@ -1185,7 +1186,7 @@ class TestDocumentPermissionCases: "preview": [], } - def test_document_tenant_mismatch(self, app): + def test_document_tenant_mismatch(self, app: Flask): api = DocumentApi() method = unwrap(api.get) @@ -1253,7 +1254,7 @@ class TestDocumentPermissionCases: assert status == 200 assert response["mode"] == "custom" - def test_process_rule_permission_denied(self, app): + def test_process_rule_permission_denied(self, app: Flask): api = GetProcessRuleApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 693b06e95b..412edb9dfe 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, NotFound import services @@ -82,7 +83,7 @@ def test_get_segment_with_summary(monkeypatch): class TestDatasetDocumentSegmentListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -132,7 +133,7 @@ class TestDatasetDocumentSegmentListApi: assert status == 200 - def test_get_dataset_not_found(self, app): + def test_get_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -150,7 +151,7 @@ class TestDatasetDocumentSegmentListApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_get_permission_denied(self, app): + def test_get_permission_denied(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -176,7 +177,7 @@ class TestDatasetDocumentSegmentListApi: class TestDatasetDocumentSegmentApi: - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -221,7 +222,7 @@ class TestDatasetDocumentSegmentApi: assert status == 200 assert response["result"] == "success" - def test_patch_document_indexing_in_progress(self, app): + def test_patch_document_indexing_in_progress(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -264,7 +265,7 @@ class TestDatasetDocumentSegmentApi: with pytest.raises(InvalidActionError): method(api, "ds-1", "doc-1", "disable") - def test_patch_llm_bad_request(self, app): + def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -308,7 +309,7 @@ class TestDatasetDocumentSegmentApi: with pytest.raises(ProviderNotInitializeError): method(api, "ds-1", "doc-1", "enable") - def test_patch_provider_token_not_init(self, app): + def test_patch_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentApi() method = unwrap(api.patch) @@ -354,7 +355,7 @@ class TestDatasetDocumentSegmentApi: class TestDatasetDocumentSegmentAddApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -413,7 +414,7 @@ class TestDatasetDocumentSegmentAddApi: assert status == 200 assert response["data"]["id"] == "seg-1" - def test_post_llm_bad_request(self, app): + def test_post_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -452,7 +453,7 @@ class TestDatasetDocumentSegmentAddApi: with pytest.raises(ProviderNotInitializeError): method(api, "ds-1", "doc-1") - def test_post_provider_token_not_init(self, app): + def test_post_provider_token_not_init(self, app: Flask): api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -493,7 +494,7 @@ class TestDatasetDocumentSegmentAddApi: class TestDatasetDocumentSegmentUpdateApi: - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() method = unwrap(api.patch) @@ -551,7 +552,7 @@ class TestDatasetDocumentSegmentUpdateApi: assert status == 200 assert "data" in response - def test_patch_llm_bad_request(self, app): + def test_patch_llm_bad_request(self, app: Flask): api = DatasetDocumentSegmentUpdateApi() method = unwrap(api.patch) @@ -596,7 +597,7 @@ class TestDatasetDocumentSegmentUpdateApi: class TestDatasetDocumentSegmentBatchImportApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -638,7 +639,7 @@ class TestDatasetDocumentSegmentBatchImportApi: assert status == 200 assert response["job_status"] == "waiting" - def test_post_dataset_not_found(self, app): + def test_post_dataset_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -659,7 +660,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_document_not_found(self, app): + def test_post_document_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -684,7 +685,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_upload_file_not_found(self, app): + def test_post_upload_file_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -713,7 +714,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_post_invalid_file_type(self, app): + def test_post_invalid_file_type(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -745,7 +746,7 @@ class TestDatasetDocumentSegmentBatchImportApi: with pytest.raises(ValueError): method(api, "ds-1", "doc-1") - def test_post_async_task_failure(self, app): + def test_post_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -783,7 +784,7 @@ class TestDatasetDocumentSegmentBatchImportApi: assert status == 500 assert "error" in response - def test_get_job_not_found_in_redis(self, app): + def test_get_job_not_found_in_redis(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.get) @@ -799,7 +800,7 @@ class TestDatasetDocumentSegmentBatchImportApi: class TestChildChunkAddApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = ChildChunkAddApi() method = unwrap(api.post) @@ -852,7 +853,7 @@ class TestChildChunkAddApi: assert status == 200 assert response["data"]["id"] == "cc-1" - def test_post_child_chunk_indexing_error(self, app): + def test_post_child_chunk_indexing_error(self, app: Flask): api = ChildChunkAddApi() method = unwrap(api.post) @@ -897,7 +898,7 @@ class TestChildChunkAddApi: class TestChildChunkUpdateApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ChildChunkUpdateApi() method = unwrap(api.delete) @@ -941,7 +942,7 @@ class TestChildChunkUpdateApi: assert status == 204 assert response["result"] == "success" - def test_delete_child_chunk_index_error(self, app): + def test_delete_child_chunk_index_error(self, app: Flask): api = ChildChunkUpdateApi() method = unwrap(api.delete) @@ -984,7 +985,7 @@ class TestChildChunkUpdateApi: class TestSegmentListAdvancedCases: - def test_segment_list_with_keyword_filter(self, app): + def test_segment_list_with_keyword_filter(self, app: Flask): api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1035,7 +1036,7 @@ class TestSegmentListAdvancedCases: assert status == 200 assert response["total"] == 1 - def test_segment_list_permission_denied(self, app): + def test_segment_list_permission_denied(self, app: Flask): """Test segment list with permission denied""" api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1058,7 +1059,7 @@ class TestSegmentListAdvancedCases: with pytest.raises(Forbidden): method(api, "ds-1", "doc-1") - def test_segment_list_dataset_not_found(self, app): + def test_segment_list_dataset_not_found(self, app: Flask): """Test segment list with dataset not found""" api = DatasetDocumentSegmentListApi() method = unwrap(api.get) @@ -1079,7 +1080,7 @@ class TestSegmentListAdvancedCases: class TestSegmentOperationCases: - def test_segment_add_with_provider_token_error(self, app): + def test_segment_add_with_provider_token_error(self, app: Flask): """Test segment add with provider token not initialized""" api = DatasetDocumentSegmentAddApi() method = unwrap(api.post) @@ -1117,7 +1118,7 @@ class TestSegmentOperationCases: with pytest.raises(ProviderTokenNotInitError): method(api, "ds-1", "doc-1") - def test_batch_import_with_document_not_found(self, app): + def test_batch_import_with_document_not_found(self, app: Flask): """Test batch import with document not found""" api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1146,7 +1147,7 @@ class TestSegmentOperationCases: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_batch_import_with_invalid_file(self, app): + def test_batch_import_with_invalid_file(self, app: Flask): """Test batch import with invalid file type""" api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1181,7 +1182,7 @@ class TestSegmentOperationCases: with pytest.raises(NotFound): method(api, "ds-1", "doc-1") - def test_batch_import_with_async_task_failure(self, app): + def test_batch_import_with_async_task_failure(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.post) @@ -1226,7 +1227,7 @@ class TestSegmentOperationCases: assert status == 500 assert "error" in response - def test_batch_import_get_job_not_found(self, app): + def test_batch_import_get_job_not_found(self, app: Flask): api = DatasetDocumentSegmentBatchImportApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 514bbbe040..7254bf7670 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -57,7 +57,7 @@ def mock_auth(monkeypatch, current_user): class TestExternalApiTemplateListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ExternalApiTemplateListApi() method = unwrap(api.get) @@ -78,7 +78,7 @@ class TestExternalApiTemplateListApi: assert resp["total"] == 1 assert resp["data"][0]["id"] == "1" - def test_post_forbidden(self, app, current_user): + def test_post_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False api = ExternalApiTemplateListApi() method = unwrap(api.post) @@ -93,7 +93,7 @@ class TestExternalApiTemplateListApi: with pytest.raises(Forbidden): method(api) - def test_post_duplicate_name(self, app): + def test_post_duplicate_name(self, app: Flask): api = ExternalApiTemplateListApi() method = unwrap(api.post) @@ -114,7 +114,7 @@ class TestExternalApiTemplateListApi: class TestExternalApiTemplateApi: - def test_get_not_found(self, app): + def test_get_not_found(self, app: Flask): api = ExternalApiTemplateApi() method = unwrap(api.get) @@ -129,7 +129,7 @@ class TestExternalApiTemplateApi: with pytest.raises(NotFound): method(api, "api-id") - def test_delete_forbidden(self, app, current_user): + def test_delete_forbidden(self, app: Flask, current_user): current_user.has_edit_permission = False current_user.is_dataset_operator = False @@ -142,7 +142,7 @@ class TestExternalApiTemplateApi: class TestExternalApiUseCheckApi: - def test_get_scopes_usage_check_to_current_tenant(self, app): + def test_get_scopes_usage_check_to_current_tenant(self, app: Flask): api = ExternalApiUseCheckApi() method = unwrap(api.get) @@ -162,7 +162,7 @@ class TestExternalApiUseCheckApi: class TestExternalDatasetCreateApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = ExternalDatasetCreateApi() method = unwrap(api.post) @@ -206,7 +206,7 @@ class TestExternalDatasetCreateApi: assert status == 201 - def test_create_forbidden(self, app, current_user): + def test_create_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False api = ExternalDatasetCreateApi() method = unwrap(api.post) @@ -226,7 +226,7 @@ class TestExternalDatasetCreateApi: class TestExternalKnowledgeHitTestingApi: - def test_hit_testing_dataset_not_found(self, app): + def test_hit_testing_dataset_not_found(self, app: Flask): api = ExternalKnowledgeHitTestingApi() method = unwrap(api.post) @@ -241,7 +241,7 @@ class TestExternalKnowledgeHitTestingApi: with pytest.raises(NotFound): method(api, "dataset-id") - def test_hit_testing_success(self, app): + def test_hit_testing_success(self, app: Flask): api = ExternalKnowledgeHitTestingApi() method = unwrap(api.post) @@ -266,7 +266,7 @@ class TestExternalKnowledgeHitTestingApi: class TestBedrockRetrievalApi: - def test_bedrock_retrieval(self, app): + def test_bedrock_retrieval(self, app: Flask): api = BedrockRetrievalApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py index de834c2d4d..0105aacd65 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -269,7 +269,7 @@ class TestDatasetMetadataApi: class TestDatasetMetadataBuiltInFieldApi: - def test_get_built_in_fields(self, app): + def test_get_built_in_fields(self, app: Flask): api = DatasetMetadataBuiltInFieldApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py index c8f674f515..d1cb6b6a03 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_banner.py +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from flask import Flask + import controllers.console.explore.banner as banner_module from models.enums import BannerStatus @@ -12,7 +14,7 @@ def unwrap(func): class TestBannerApi: - def test_get_banners_with_requested_language(self, app): + def test_get_banners_with_requested_language(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) @@ -41,7 +43,7 @@ class TestBannerApi: } ] - def test_get_banners_fallback_to_en_us(self, app): + def test_get_banners_fallback_to_en_us(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) @@ -76,7 +78,7 @@ class TestBannerApi: } ] - def test_get_banners_default_language_en_us(self, app): + def test_get_banners_default_language_en_us(self, app: Flask): api = banner_module.BannerApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 145cc9cdd7..3d41489435 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import InternalServerError, NotFound import controllers.console.explore.message as module @@ -54,7 +55,7 @@ def make_message(): class TestMessageListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -96,7 +97,7 @@ class TestMessageListApi: with pytest.raises(NotChatAppError): method(installed_app) - def test_conversation_not_exists(self, app): + def test_conversation_not_exists(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -118,7 +119,7 @@ class TestMessageListApi: with pytest.raises(NotFound): method(installed_app) - def test_first_message_not_exists(self, app): + def test_first_message_not_exists(self, app: Flask): api = module.MessageListApi() method = unwrap(api.get) @@ -142,7 +143,7 @@ class TestMessageListApi: class TestMessageFeedbackApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = module.MessageFeedbackApi() method = unwrap(api.post) @@ -161,7 +162,7 @@ class TestMessageFeedbackApi: assert result["result"] == "success" - def test_message_not_exists(self, app): + def test_message_not_exists(self, app: Flask): api = module.MessageFeedbackApi() method = unwrap(api.post) @@ -182,7 +183,7 @@ class TestMessageFeedbackApi: class TestMessageMoreLikeThisApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -221,7 +222,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(NotCompletionAppError): method(installed_app, "mid") - def test_more_like_this_disabled(self, app): + def test_more_like_this_disabled(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -243,7 +244,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(AppMoreLikeThisDisabledError): method(installed_app, "mid") - def test_message_not_exists_more_like_this(self, app): + def test_message_not_exists_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -265,7 +266,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(NotFound): method(installed_app, "mid") - def test_provider_not_init_more_like_this(self, app): + def test_provider_not_init_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -287,7 +288,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderNotInitializeError): method(installed_app, "mid") - def test_quota_exceeded_more_like_this(self, app): + def test_quota_exceeded_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -309,7 +310,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderQuotaExceededError): method(installed_app, "mid") - def test_model_not_support_more_like_this(self, app): + def test_model_not_support_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -331,7 +332,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(ProviderModelCurrentlyNotSupportError): method(installed_app, "mid") - def test_invoke_error_more_like_this(self, app): + def test_invoke_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) @@ -353,7 +354,7 @@ class TestMessageMoreLikeThisApi: with pytest.raises(CompletionRequestError): method(installed_app, "mid") - def test_unexpected_error_more_like_this(self, app): + def test_unexpected_error_more_like_this(self, app: Flask): api = module.MessageMoreLikeThisApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 76c863577a..557fded37e 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,5 +1,7 @@ from unittest.mock import MagicMock, patch +from flask import Flask + import controllers.console.explore.recommended_app as module from models.model import AppMode, IconType @@ -11,7 +13,7 @@ def unwrap(func): class TestRecommendedAppListApi: - def test_get_with_language_param(self, app): + def test_get_with_language_param(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -31,7 +33,7 @@ class TestRecommendedAppListApi: service_mock.assert_called_once_with("en-US") assert result == result_data - def test_get_fallback_to_user_language(self, app): + def test_get_fallback_to_user_language(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -51,7 +53,7 @@ class TestRecommendedAppListApi: service_mock.assert_called_once_with("fr-FR") assert result == result_data - def test_get_fallback_to_default_language(self, app): + def test_get_fallback_to_default_language(self, app: Flask): api = module.RecommendedAppListApi() method = unwrap(api.get) @@ -73,7 +75,7 @@ class TestRecommendedAppListApi: class TestRecommendedAppApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.RecommendedAppApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index bb7cdd55c4..71241890e9 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, PropertyMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.saved_message as module @@ -42,7 +43,7 @@ def payload_patch(): class TestSavedMessageListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = module.SavedMessageListApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 3625056af9..14f00e6295 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -88,7 +89,7 @@ def valid_parameters(): class TestTrialAppWorkflowRunApi: - def test_not_workflow_app(self, app): + def test_not_workflow_app(self, app: Flask): api = module.TrialAppWorkflowRunApi() method = unwrap(api.post) @@ -224,7 +225,7 @@ class TestTrialAppWorkflowRunApi: class TestTrialChatApi: - def test_not_chat_app(self, app): + def test_not_chat_app(self, app: Flask): api = module.TrialChatApi() method = unwrap(api.post) @@ -408,7 +409,7 @@ class TestTrialChatApi: class TestTrialCompletionApi: - def test_not_completion_app(self, app): + def test_not_completion_app(self, app: Flask): api = module.TrialCompletionApi() method = unwrap(api.post) @@ -560,7 +561,7 @@ class TestTrialCompletionApi: class TestTrialMessageSuggestedQuestionApi: - def test_not_chat_app(self, app): + def test_not_chat_app(self, app: Flask): api = module.TrialMessageSuggestedQuestionApi() method = unwrap(api.get) @@ -952,7 +953,7 @@ class TestTrialAppWorkflowTaskStopApi: class TestTrialSitApi: - def test_no_site(self, app): + def test_no_site(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) app_model = MagicMock() @@ -963,7 +964,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_archived_tenant(self, app): + def test_archived_tenant(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) @@ -978,7 +979,7 @@ class TestTrialSitApi: with pytest.raises(Forbidden): method(api, app_model) - def test_success(self, app): + def test_success(self, app: Flask): api = module.TrialSitApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index a26d171649..8b47da25fb 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -73,7 +73,7 @@ def payload_patch(): class TestTagListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = TagListApi() method = unwrap(api.get) @@ -124,7 +124,7 @@ class TestTagListApi: assert result["name"] == "test-tag" assert result["binding_count"] == "0" - def test_post_forbidden(self, app, readonly_user, payload_patch): + def test_post_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagListApi() method = unwrap(api.post) @@ -170,7 +170,7 @@ class TestTagUpdateDeleteApi: assert status == 200 assert result["binding_count"] == "3" - def test_patch_forbidden(self, app, readonly_user, payload_patch): + def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagUpdateDeleteApi() method = unwrap(api.patch) @@ -231,7 +231,7 @@ class TestTagBindingCollectionApi: assert status == 200 assert result["result"] == "success" - def test_create_forbidden(self, app, readonly_user, payload_patch): + def test_create_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagBindingCollectionApi() method = unwrap(api.post) @@ -275,7 +275,7 @@ class TestTagBindingRemoveApi: assert status == 200 assert result["result"] == "success" - def test_remove_forbidden(self, app, readonly_user, payload_patch): + def test_remove_forbidden(self, app: Flask, readonly_user, payload_patch): api = TagBindingRemoveApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index 5df9daa7f8..eebc6f9d60 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -82,7 +82,7 @@ def mock_file_service(mock_db): class TestFileApiGet: - def test_get_upload_config(self, app): + def test_get_upload_config(self, app: Flask): api = FileApi() get_method = unwrap(api.get) @@ -290,7 +290,7 @@ class TestFilePreviewApi: class TestFileSupportTypeApi: - def test_get_supported_types(self, app): + def test_get_supported_types(self, app: Flask): api = FileSupportTypeApi() get_method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/test_workspace_account.py b/api/tests/unit_tests/controllers/console/test_workspace_account.py index 0b1a32581a..4b4f968c8f 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_account.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_account.py @@ -58,7 +58,7 @@ class TestChangeEmailSend: mock_get_change_data, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("current@example.com", "acc1") @@ -107,7 +107,7 @@ class TestChangeEmailSend: mock_get_change_data, mock_current_account, mock_db, - app, + app: Flask, ): """GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step.""" from controllers.console.auth.error import InvalidTokenError @@ -155,7 +155,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_account = _build_account("user@example.com", "acc2") @@ -214,7 +214,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) mock_current_account.return_value = (_build_account("old@example.com", "acc"), None) @@ -267,7 +267,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): """A token whose phase marker is a string but not a known transition must be rejected.""" from controllers.console.auth.error import InvalidTokenError @@ -316,7 +316,7 @@ class TestChangeEmailValidity: mock_reset_rate, mock_current_account, mock_db, - app, + app: Flask, ): """A token minted without a phase marker (e.g. a hand-crafted token) must not validate.""" from controllers.console.auth.error import InvalidTokenError @@ -366,7 +366,7 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): mock_features.return_value = SimpleNamespace(enable_change_email=True) current_user = _build_account("old@example.com", "acc3") @@ -418,7 +418,7 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): """GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset.""" from controllers.console.auth.error import InvalidTokenError @@ -471,7 +471,7 @@ class TestChangeEmailReset: mock_send_notify, mock_current_account, mock_db, - app, + app: Flask, ): """A verified token for address A must not be replayed to change to address B.""" from controllers.console.auth.error import InvalidTokenError @@ -547,7 +547,7 @@ class TestAccountServiceSendChangeEmailEmail: class TestAccountDeletionFeedback: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback") - def test_should_normalize_feedback_email(self, mock_update, mock_db, app): + def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask): with app.test_request_context( "/account/delete/feedback", method="POST", @@ -563,7 +563,7 @@ class TestCheckEmailUnique: @patch("controllers.console.wraps.db") @patch("controllers.console.workspace.account.AccountService.check_email_unique") @patch("controllers.console.workspace.account.AccountService.is_account_in_freeze") - def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app): + def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask): mock_is_freeze.return_value = False mock_check_unique.return_value = True diff --git a/api/tests/unit_tests/controllers/console/test_workspace_members.py b/api/tests/unit_tests/controllers/console/test_workspace_members.py index 811bf5b1e7..412d6a6c52 100644 --- a/api/tests/unit_tests/controllers/console/test_workspace_members.py +++ b/api/tests/unit_tests/controllers/console/test_workspace_members.py @@ -43,7 +43,7 @@ class TestMemberInviteEmailApi: mock_current_account, mock_invite_member, mock_get_features, - app, + app: Flask, ): mock_get_features.return_value = _build_feature_flags() mock_invite_member.return_value = "token-abc" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py index bbe9d09521..064726da05 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -41,7 +42,7 @@ def unwrap(func): class TestAccountInitApi: - def test_init_success(self, app): + def test_init_success(self, app: Flask): api = AccountInitApi() method = unwrap(api.post) @@ -64,7 +65,7 @@ class TestAccountInitApi: assert resp["result"] == "success" - def test_init_already_initialized(self, app): + def test_init_already_initialized(self, app: Flask): api = AccountInitApi() method = unwrap(api.post) @@ -79,7 +80,7 @@ class TestAccountInitApi: class TestAccountProfileApi: - def test_get_profile_success(self, app): + def test_get_profile_success(self, app: Flask): api = AccountProfileApi() method = unwrap(api.get) @@ -140,7 +141,7 @@ class TestAccountUpdateApis: class TestAccountAvatarApiGet: """GET /account/avatar must not sign arbitrary upload_file IDs (IDOR).""" - def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app): + def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -172,7 +173,7 @@ class TestAccountAvatarApiGet: assert result == {"avatar_url": "https://signed/example"} sign_mock.assert_called_once_with(upload_file_id=file_id) - def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app): + def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -204,7 +205,7 @@ class TestAccountAvatarApiGet: sign_mock.assert_not_called() - def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app): + def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -236,7 +237,7 @@ class TestAccountAvatarApiGet: sign_mock.assert_not_called() - def test_get_avatar_https_pass_through_without_signing(self, app): + def test_get_avatar_https_pass_through_without_signing(self, app: Flask): api = AccountAvatarApi() method = unwrap(api.get) @@ -263,7 +264,7 @@ class TestAccountAvatarApiGet: class TestAccountPasswordApi: - def test_password_success(self, app): + def test_password_success(self, app: Flask): api = AccountPasswordApi() method = unwrap(api.post) @@ -292,7 +293,7 @@ class TestAccountPasswordApi: assert result["id"] == "u1" - def test_password_wrong_current(self, app): + def test_password_wrong_current(self, app: Flask): api = AccountPasswordApi() method = unwrap(api.post) @@ -317,7 +318,7 @@ class TestAccountPasswordApi: class TestAccountIntegrateApi: - def test_get_integrates(self, app): + def test_get_integrates(self, app: Flask): api = AccountIntegrateApi() method = unwrap(api.get) @@ -336,7 +337,7 @@ class TestAccountIntegrateApi: class TestAccountDeleteApi: - def test_delete_verify_success(self, app): + def test_delete_verify_success(self, app: Flask): api = AccountDeleteVerifyApi() method = unwrap(api.get) @@ -358,7 +359,7 @@ class TestAccountDeleteApi: assert result["result"] == "success" - def test_delete_invalid_code(self, app): + def test_delete_invalid_code(self, app: Flask): api = AccountDeleteApi() method = unwrap(api.post) @@ -379,7 +380,7 @@ class TestAccountDeleteApi: class TestChangeEmailApis: - def test_check_email_code_invalid(self, app): + def test_check_email_code_invalid(self, app: Flask): api = ChangeEmailCheckApi() method = unwrap(api.post) @@ -405,7 +406,7 @@ class TestChangeEmailApis: with pytest.raises(EmailCodeError): method(api) - def test_reset_email_already_used(self, app): + def test_reset_email_already_used(self, app: Flask): api = ChangeEmailResetApi() method = unwrap(api.post) @@ -427,7 +428,7 @@ class TestChangeEmailApis: class TestCheckEmailUniqueApi: - def test_email_unique_success(self, app): + def test_email_unique_success(self, app: Flask): api = CheckEmailUnique() method = unwrap(api.post) @@ -448,7 +449,7 @@ class TestCheckEmailUniqueApi: assert result["result"] == "success" - def test_email_in_freeze(self, app): + def test_email_in_freeze(self, app: Flask): api = CheckEmailUnique() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py index b4e03f681d..eb0ca15d2e 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.error import AccountNotFound from controllers.console.workspace.agent_providers import ( @@ -16,7 +17,7 @@ def unwrap(func): class TestAgentProviderListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = AgentProviderListApi() method = unwrap(api.get) @@ -39,7 +40,7 @@ class TestAgentProviderListApi: assert result == providers - def test_get_empty_list(self, app): + def test_get_empty_list(self, app: Flask): api = AgentProviderListApi() method = unwrap(api.get) @@ -61,7 +62,7 @@ class TestAgentProviderListApi: assert result == [] - def test_get_account_not_found(self, app): + def test_get_account_not_found(self, app: Flask): api = AgentProviderListApi() method = unwrap(api.get) @@ -77,7 +78,7 @@ class TestAgentProviderListApi: class TestAgentProviderApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = AgentProviderApi() method = unwrap(api.get) @@ -101,7 +102,7 @@ class TestAgentProviderApi: assert result == provider_data - def test_get_provider_not_found(self, app): + def test_get_provider_not_found(self, app: Flask): api = AgentProviderApi() method = unwrap(api.get) @@ -124,7 +125,7 @@ class TestAgentProviderApi: assert result is None - def test_get_account_not_found(self, app): + def test_get_account_not_found(self, app: Flask): api = AgentProviderApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py index 0b3d7ef6d7..ed7b2d606f 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console import console_ns from controllers.console.workspace.endpoint import ( @@ -39,7 +40,7 @@ def patch_current_account(user_and_tenant): @pytest.mark.usefixtures("patch_current_account") class TestEndpointCollectionApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = EndpointCollectionApi() method = unwrap(api.post) @@ -57,7 +58,7 @@ class TestEndpointCollectionApi: assert result["success"] is True - def test_create_permission_denied(self, app): + def test_create_permission_denied(self, app: Flask): api = EndpointCollectionApi() method = unwrap(api.post) @@ -77,7 +78,7 @@ class TestEndpointCollectionApi: with pytest.raises(ValueError): method(api) - def test_create_validation_error(self, app): + def test_create_validation_error(self, app: Flask): api = EndpointCollectionApi() method = unwrap(api.post) @@ -96,7 +97,7 @@ class TestEndpointCollectionApi: @pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointCreateApi: - def test_create_success(self, app): + def test_create_success(self, app: Flask): api = DeprecatedEndpointCreateApi() method = unwrap(api.post) @@ -117,7 +118,7 @@ class TestDeprecatedEndpointCreateApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointListApi: - def test_list_success(self, app): + def test_list_success(self, app: Flask): api = EndpointListApi() method = unwrap(api.get) @@ -130,7 +131,7 @@ class TestEndpointListApi: assert "endpoints" in result assert len(result["endpoints"]) == 1 - def test_list_invalid_query(self, app): + def test_list_invalid_query(self, app: Flask): api = EndpointListApi() method = unwrap(api.get) @@ -143,7 +144,7 @@ class TestEndpointListApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointListForSinglePluginApi: - def test_list_for_plugin_success(self, app): + def test_list_for_plugin_success(self, app: Flask): api = EndpointListForSinglePluginApi() method = unwrap(api.get) @@ -158,7 +159,7 @@ class TestEndpointListForSinglePluginApi: assert "endpoints" in result - def test_list_for_plugin_missing_param(self, app): + def test_list_for_plugin_missing_param(self, app: Flask): api = EndpointListForSinglePluginApi() method = unwrap(api.get) @@ -171,7 +172,7 @@ class TestEndpointListForSinglePluginApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointItemApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = EndpointItemApi() method = unwrap(api.delete) @@ -187,7 +188,7 @@ class TestEndpointItemApi: assert result["success"] is True mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1") - def test_delete_service_failure(self, app): + def test_delete_service_failure(self, app: Flask): api = EndpointItemApi() method = unwrap(api.delete) @@ -199,7 +200,7 @@ class TestEndpointItemApi: assert result["success"] is False - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = EndpointItemApi() method = unwrap(api.patch) @@ -226,7 +227,7 @@ class TestEndpointItemApi: settings={"x": 1}, ) - def test_update_validation_error(self, app): + def test_update_validation_error(self, app: Flask): api = EndpointItemApi() method = unwrap(api.patch) @@ -238,7 +239,7 @@ class TestEndpointItemApi: with pytest.raises(ValueError): method(api, "e1") - def test_update_service_failure(self, app): + def test_update_service_failure(self, app: Flask): api = EndpointItemApi() method = unwrap(api.patch) @@ -258,7 +259,7 @@ class TestEndpointItemApi: @pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointDeleteApi: - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) @@ -272,7 +273,7 @@ class TestDeprecatedEndpointDeleteApi: assert result["success"] is True - def test_delete_invalid_payload(self, app): + def test_delete_invalid_payload(self, app: Flask): api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) @@ -282,7 +283,7 @@ class TestDeprecatedEndpointDeleteApi: with pytest.raises(ValueError): method(api) - def test_delete_service_failure(self, app): + def test_delete_service_failure(self, app: Flask): api = DeprecatedEndpointDeleteApi() method = unwrap(api.post) @@ -299,7 +300,7 @@ class TestDeprecatedEndpointDeleteApi: @pytest.mark.usefixtures("patch_current_account") class TestDeprecatedEndpointUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) @@ -317,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi: assert result["success"] is True - def test_update_validation_error(self, app): + def test_update_validation_error(self, app: Flask): api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) @@ -329,7 +330,7 @@ class TestDeprecatedEndpointUpdateApi: with pytest.raises(ValueError): method(api) - def test_update_service_failure(self, app): + def test_update_service_failure(self, app: Flask): api = DeprecatedEndpointUpdateApi() method = unwrap(api.post) @@ -380,7 +381,7 @@ class TestEndpointRouteMetadata: @pytest.mark.usefixtures("patch_current_account") class TestEndpointEnableApi: - def test_enable_success(self, app): + def test_enable_success(self, app: Flask): api = EndpointEnableApi() method = unwrap(api.post) @@ -394,7 +395,7 @@ class TestEndpointEnableApi: assert result["success"] is True - def test_enable_invalid_payload(self, app): + def test_enable_invalid_payload(self, app: Flask): api = EndpointEnableApi() method = unwrap(api.post) @@ -404,7 +405,7 @@ class TestEndpointEnableApi: with pytest.raises(ValueError): method(api) - def test_enable_service_failure(self, app): + def test_enable_service_failure(self, app: Flask): api = EndpointEnableApi() method = unwrap(api.post) @@ -421,7 +422,7 @@ class TestEndpointEnableApi: @pytest.mark.usefixtures("patch_current_account") class TestEndpointDisableApi: - def test_disable_success(self, app): + def test_disable_success(self, app: Flask): api = EndpointDisableApi() method = unwrap(api.post) @@ -435,7 +436,7 @@ class TestEndpointDisableApi: assert result["success"] is True - def test_disable_invalid_payload(self, app): + def test_disable_invalid_payload(self, app: Flask): api = EndpointDisableApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py index 718b57ba6b..0788ff603c 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_members.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import HTTPException import services @@ -34,7 +35,7 @@ def unwrap(func): class TestMemberListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = MemberListApi() method = unwrap(api.get) @@ -59,7 +60,7 @@ class TestMemberListApi: assert status == 200 assert len(result["accounts"]) == 1 - def test_get_no_tenant(self, app): + def test_get_no_tenant(self, app: Flask): api = MemberListApi() method = unwrap(api.get) @@ -74,7 +75,7 @@ class TestMemberListApi: class TestMemberInviteEmailApi: - def test_invite_success(self, app): + def test_invite_success(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -101,7 +102,7 @@ class TestMemberInviteEmailApi: assert status == 201 assert result["result"] == "success" - def test_invite_limit_exceeded(self, app): + def test_invite_limit_exceeded(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -123,7 +124,7 @@ class TestMemberInviteEmailApi: with pytest.raises(WorkspaceMembersLimitExceeded): method(api) - def test_invite_already_member(self, app): + def test_invite_already_member(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -151,7 +152,7 @@ class TestMemberInviteEmailApi: assert result["invitation_results"][0]["status"] == "success" - def test_invite_invalid_role(self, app): + def test_invite_invalid_role(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -166,7 +167,7 @@ class TestMemberInviteEmailApi: assert status == 400 assert result["code"] == "invalid-role" - def test_invite_generic_exception(self, app): + def test_invite_generic_exception(self, app: Flask): api = MemberInviteEmailApi() method = unwrap(api.post) @@ -196,7 +197,7 @@ class TestMemberInviteEmailApi: class TestMemberCancelInviteApi: - def test_cancel_success(self, app): + def test_cancel_success(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -216,7 +217,7 @@ class TestMemberCancelInviteApi: assert status == 200 assert result["result"] == "success" - def test_cancel_not_found(self, app): + def test_cancel_not_found(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -233,7 +234,7 @@ class TestMemberCancelInviteApi: with pytest.raises(HTTPException): method(api, "x") - def test_cancel_cannot_operate_self(self, app): + def test_cancel_cannot_operate_self(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -255,7 +256,7 @@ class TestMemberCancelInviteApi: assert status == 400 - def test_cancel_no_permission(self, app): + def test_cancel_no_permission(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -277,7 +278,7 @@ class TestMemberCancelInviteApi: assert status == 403 - def test_cancel_member_not_in_tenant(self, app): + def test_cancel_member_not_in_tenant(self, app: Flask): api = MemberCancelInviteApi() method = unwrap(api.delete) @@ -301,7 +302,7 @@ class TestMemberCancelInviteApi: class TestMemberUpdateRoleApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = MemberUpdateRoleApi() method = unwrap(api.put) @@ -324,7 +325,7 @@ class TestMemberUpdateRoleApi: assert result["result"] == "success" - def test_update_invalid_role(self, app): + def test_update_invalid_role(self, app: Flask): api = MemberUpdateRoleApi() method = unwrap(api.put) @@ -335,7 +336,7 @@ class TestMemberUpdateRoleApi: assert status == 400 - def test_update_member_not_found(self, app): + def test_update_member_not_found(self, app: Flask): api = MemberUpdateRoleApi() method = unwrap(api.put) @@ -354,7 +355,7 @@ class TestMemberUpdateRoleApi: class TestDatasetOperatorMemberListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = DatasetOperatorMemberListApi() method = unwrap(api.get) @@ -381,7 +382,7 @@ class TestDatasetOperatorMemberListApi: assert status == 200 assert len(result["accounts"]) == 1 - def test_get_no_tenant(self, app): + def test_get_no_tenant(self, app: Flask): api = DatasetOperatorMemberListApi() method = unwrap(api.get) @@ -396,7 +397,7 @@ class TestDatasetOperatorMemberListApi: class TestSendOwnerTransferEmailApi: - def test_send_success(self, app): + def test_send_success(self, app: Flask): api = SendOwnerTransferEmailApi() method = unwrap(api.post) @@ -419,7 +420,7 @@ class TestSendOwnerTransferEmailApi: assert result["result"] == "success" - def test_send_ip_limit(self, app): + def test_send_ip_limit(self, app: Flask): api = SendOwnerTransferEmailApi() method = unwrap(api.post) @@ -433,7 +434,7 @@ class TestSendOwnerTransferEmailApi: with pytest.raises(EmailSendIpLimitError): method(api) - def test_send_not_owner(self, app): + def test_send_not_owner(self, app: Flask): api = SendOwnerTransferEmailApi() method = unwrap(api.post) @@ -452,7 +453,7 @@ class TestSendOwnerTransferEmailApi: class TestOwnerTransferCheckApi: - def test_check_invalid_code(self, app): + def test_check_invalid_code(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -477,7 +478,7 @@ class TestOwnerTransferCheckApi: with pytest.raises(EmailCodeError): method(api) - def test_rate_limited(self, app): + def test_rate_limited(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -498,7 +499,7 @@ class TestOwnerTransferCheckApi: with pytest.raises(OwnerTransferLimitError): method(api) - def test_invalid_token(self, app): + def test_invalid_token(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -520,7 +521,7 @@ class TestOwnerTransferCheckApi: with pytest.raises(InvalidTokenError): method(api) - def test_invalid_email(self, app): + def test_invalid_email(self, app: Flask): api = OwnerTransferCheckApi() method = unwrap(api.post) @@ -547,7 +548,7 @@ class TestOwnerTransferCheckApi: class TestOwnerTransferApi: - def test_transfer_self(self, app): + def test_transfer_self(self, app: Flask): api = OwnerTransfer() method = unwrap(api.post) @@ -564,7 +565,7 @@ class TestOwnerTransferApi: with pytest.raises(CannotTransferOwnerToSelfError): method(api, "1") - def test_invalid_token(self, app): + def test_invalid_token(self, app: Flask): api = OwnerTransfer() method = unwrap(api.post) @@ -582,7 +583,7 @@ class TestOwnerTransferApi: with pytest.raises(InvalidTokenError): method(api, "2") - def test_member_not_in_tenant(self, app): + def test_member_not_in_tenant(self, app: Flask): api = OwnerTransfer() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index 168479af1e..e836a3cc55 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from flask import Flask from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -26,7 +27,7 @@ def unwrap(func): class TestModelProviderListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ModelProviderListApi() method = unwrap(api.get) @@ -47,7 +48,7 @@ class TestModelProviderListApi: class TestModelProviderCredentialApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.get) @@ -66,7 +67,7 @@ class TestModelProviderCredentialApi: assert "credentials" in result - def test_get_invalid_uuid(self, app): + def test_get_invalid_uuid(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.get) @@ -80,7 +81,7 @@ class TestModelProviderCredentialApi: with pytest.raises(ValidationError): method(api, provider="openai") - def test_post_create_success(self, app): + def test_post_create_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.post) @@ -102,7 +103,7 @@ class TestModelProviderCredentialApi: assert result["result"] == "success" assert status == 201 - def test_post_create_validation_error(self, app): + def test_post_create_validation_error(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.post) @@ -122,7 +123,7 @@ class TestModelProviderCredentialApi: with pytest.raises(ValueError): method(api, provider="openai") - def test_put_update_success(self, app): + def test_put_update_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.put) @@ -143,7 +144,7 @@ class TestModelProviderCredentialApi: assert result["result"] == "success" - def test_put_invalid_uuid(self, app): + def test_put_invalid_uuid(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.put) @@ -159,7 +160,7 @@ class TestModelProviderCredentialApi: with pytest.raises(ValidationError): method(api, provider="openai") - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ModelProviderCredentialApi() method = unwrap(api.delete) @@ -183,7 +184,7 @@ class TestModelProviderCredentialApi: class TestModelProviderCredentialSwitchApi: - def test_switch_success(self, app): + def test_switch_success(self, app: Flask): api = ModelProviderCredentialSwitchApi() method = unwrap(api.post) @@ -204,7 +205,7 @@ class TestModelProviderCredentialSwitchApi: assert result["result"] == "success" - def test_switch_invalid_uuid(self, app): + def test_switch_invalid_uuid(self, app: Flask): api = ModelProviderCredentialSwitchApi() method = unwrap(api.post) @@ -222,7 +223,7 @@ class TestModelProviderCredentialSwitchApi: class TestModelProviderValidateApi: - def test_validate_success(self, app): + def test_validate_success(self, app: Flask): api = ModelProviderValidateApi() method = unwrap(api.post) @@ -243,7 +244,7 @@ class TestModelProviderValidateApi: assert result["result"] == "success" - def test_validate_failure(self, app): + def test_validate_failure(self, app: Flask): api = ModelProviderValidateApi() method = unwrap(api.post) @@ -266,7 +267,7 @@ class TestModelProviderValidateApi: class TestModelProviderIconApi: - def test_icon_success(self, app): + def test_icon_success(self, app: Flask): api = ModelProviderIconApi() with ( @@ -280,7 +281,7 @@ class TestModelProviderIconApi: assert response.mimetype == "image/png" - def test_icon_not_found(self, app): + def test_icon_not_found(self, app: Flask): api = ModelProviderIconApi() with ( @@ -295,7 +296,7 @@ class TestModelProviderIconApi: class TestPreferredProviderTypeUpdateApi: - def test_update_success(self, app): + def test_update_success(self, app: Flask): api = PreferredProviderTypeUpdateApi() method = unwrap(api.post) @@ -316,7 +317,7 @@ class TestPreferredProviderTypeUpdateApi: assert result["result"] == "success" - def test_invalid_enum(self, app): + def test_invalid_enum(self, app: Flask): api = PreferredProviderTypeUpdateApi() method = unwrap(api.post) @@ -334,7 +335,7 @@ class TestPreferredProviderTypeUpdateApi: class TestModelProviderPaymentCheckoutUrlApi: - def test_checkout_success(self, app): + def test_checkout_success(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() method = unwrap(api.get) @@ -359,7 +360,7 @@ class TestModelProviderPaymentCheckoutUrlApi: assert "url" in result - def test_invalid_provider(self, app): + def test_invalid_provider(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() method = unwrap(api.get) @@ -367,7 +368,7 @@ class TestModelProviderPaymentCheckoutUrlApi: with pytest.raises(ValueError): method(api, provider="openai") - def test_permission_denied(self, app): + def test_permission_denied(self, app: Flask): api = ModelProviderPaymentCheckoutUrlApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index f0d32f81fb..4246e3c04c 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -72,7 +72,7 @@ class TestDefaultModelApi: assert result["result"] == "success" - def test_get_returns_empty_when_no_default(self, app): + def test_get_returns_empty_when_no_default(self, app: Flask): api = DefaultModelApi() method = unwrap(api.get) @@ -154,7 +154,7 @@ class TestModelProviderModelApi: assert status == 204 - def test_get_models_returns_empty(self, app): + def test_get_models_returns_empty(self, app: Flask): api = ModelProviderModelApi() method = unwrap(api.get) @@ -224,7 +224,7 @@ class TestModelProviderModelCredentialApi: assert status == 201 - def test_get_empty_credentials(self, app): + def test_get_empty_credentials(self, app: Flask): api = ModelProviderModelCredentialApi() method = unwrap(api.get) @@ -242,7 +242,7 @@ class TestModelProviderModelCredentialApi: assert result["credentials"] == {} - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = ModelProviderModelCredentialApi() method = unwrap(api.delete) @@ -416,7 +416,7 @@ class TestParameterAndAvailableModels: assert "data" in result - def test_empty_rules(self, app): + def test_empty_rules(self, app: Flask): api = ModelProviderModelParameterRuleApi() method = unwrap(api.get) @@ -431,7 +431,7 @@ class TestParameterAndAvailableModels: assert result["data"] == [] - def test_no_models(self, app): + def test_no_models(self, app: Flask): api = ModelProviderAvailableModelApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index ce5fd1c466..d01bf7d668 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -2,6 +2,7 @@ import io from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -61,7 +62,7 @@ def tenant(): class TestPluginListLatestVersionsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginListLatestVersionsApi() method = unwrap(api.post) @@ -77,7 +78,7 @@ class TestPluginListLatestVersionsApi: assert "versions" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginListLatestVersionsApi() method = unwrap(api.post) @@ -95,7 +96,7 @@ class TestPluginListLatestVersionsApi: class TestPluginDebuggingKeyApi: - def test_debugging_key_success(self, app): + def test_debugging_key_success(self, app: Flask): api = PluginDebuggingKeyApi() method = unwrap(api.get) @@ -108,7 +109,7 @@ class TestPluginDebuggingKeyApi: assert result["key"] == "k" - def test_debugging_key_error(self, app): + def test_debugging_key_error(self, app: Flask): api = PluginDebuggingKeyApi() method = unwrap(api.get) @@ -125,7 +126,7 @@ class TestPluginDebuggingKeyApi: class TestPluginListApi: - def test_plugin_list(self, app): + def test_plugin_list(self, app: Flask): api = PluginListApi() method = unwrap(api.get) @@ -142,7 +143,7 @@ class TestPluginListApi: class TestPluginIconApi: - def test_plugin_icon(self, app): + def test_plugin_icon(self, app: Flask): api = PluginIconApi() method = unwrap(api.get) @@ -156,7 +157,7 @@ class TestPluginIconApi: class TestPluginAssetApi: - def test_plugin_asset(self, app): + def test_plugin_asset(self, app: Flask): api = PluginAssetApi() method = unwrap(api.get) @@ -171,7 +172,7 @@ class TestPluginAssetApi: class TestPluginUploadFromPkgApi: - def test_upload_pkg_success(self, app): + def test_upload_pkg_success(self, app: Flask): api = PluginUploadFromPkgApi() method = unwrap(api.post) @@ -188,7 +189,7 @@ class TestPluginUploadFromPkgApi: assert result["ok"] is True - def test_upload_pkg_too_large(self, app): + def test_upload_pkg_too_large(self, app: Flask): api = PluginUploadFromPkgApi() method = unwrap(api.post) @@ -210,7 +211,7 @@ class TestPluginUploadFromPkgApi: class TestPluginInstallFromPkgApi: - def test_install_from_pkg(self, app): + def test_install_from_pkg(self, app: Flask): api = PluginInstallFromPkgApi() method = unwrap(api.post) @@ -229,7 +230,7 @@ class TestPluginInstallFromPkgApi: class TestPluginUninstallApi: - def test_uninstall(self, app): + def test_uninstall(self, app: Flask): api = PluginUninstallApi() method = unwrap(api.post) @@ -246,7 +247,7 @@ class TestPluginUninstallApi: class TestPluginChangePermissionApi: - def test_change_permission_forbidden(self, app): + def test_change_permission_forbidden(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) @@ -264,7 +265,7 @@ class TestPluginChangePermissionApi: with pytest.raises(Forbidden): method(api) - def test_change_permission_success(self, app): + def test_change_permission_success(self, app: Flask): api = PluginChangePermissionApi() method = unwrap(api.post) @@ -286,7 +287,7 @@ class TestPluginChangePermissionApi: class TestPluginFetchPermissionApi: - def test_fetch_permission_default(self, app): + def test_fetch_permission_default(self, app: Flask): api = PluginFetchPermissionApi() method = unwrap(api.get) @@ -319,7 +320,7 @@ class TestPluginFetchDynamicSelectOptionsApi: class TestPluginReadmeApi: - def test_fetch_readme(self, app): + def test_fetch_readme(self, app: Flask): api = PluginReadmeApi() method = unwrap(api.get) @@ -334,7 +335,7 @@ class TestPluginReadmeApi: class TestPluginListInstallationsFromIdsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -352,7 +353,7 @@ class TestPluginListInstallationsFromIdsApi: assert "plugins" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginListInstallationsFromIdsApi() method = unwrap(api.post) @@ -371,7 +372,7 @@ class TestPluginListInstallationsFromIdsApi: class TestPluginUploadFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -388,7 +389,7 @@ class TestPluginUploadFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUploadFromGithubApi() method = unwrap(api.post) @@ -407,7 +408,7 @@ class TestPluginUploadFromGithubApi: class TestPluginUploadFromBundleApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUploadFromBundleApi() method = unwrap(api.post) @@ -430,7 +431,7 @@ class TestPluginUploadFromBundleApi: assert result["ok"] is True - def test_too_large(self, app): + def test_too_large(self, app: Flask): api = PluginUploadFromBundleApi() method = unwrap(api.post) @@ -458,7 +459,7 @@ class TestPluginUploadFromBundleApi: class TestPluginInstallFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginInstallFromGithubApi() method = unwrap(api.post) @@ -478,7 +479,7 @@ class TestPluginInstallFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginInstallFromGithubApi() method = unwrap(api.post) @@ -502,7 +503,7 @@ class TestPluginInstallFromGithubApi: class TestPluginInstallFromMarketplaceApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginInstallFromMarketplaceApi() method = unwrap(api.post) @@ -520,7 +521,7 @@ class TestPluginInstallFromMarketplaceApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginInstallFromMarketplaceApi() method = unwrap(api.post) @@ -539,7 +540,7 @@ class TestPluginInstallFromMarketplaceApi: class TestPluginFetchMarketplacePkgApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchMarketplacePkgApi() method = unwrap(api.get) @@ -552,7 +553,7 @@ class TestPluginFetchMarketplacePkgApi: assert "manifest" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchMarketplacePkgApi() method = unwrap(api.get) @@ -569,7 +570,7 @@ class TestPluginFetchMarketplacePkgApi: class TestPluginFetchManifestApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchManifestApi() method = unwrap(api.get) @@ -585,7 +586,7 @@ class TestPluginFetchManifestApi: assert "manifest" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchManifestApi() method = unwrap(api.get) @@ -602,7 +603,7 @@ class TestPluginFetchManifestApi: class TestPluginFetchInstallTasksApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchInstallTasksApi() method = unwrap(api.get) @@ -615,7 +616,7 @@ class TestPluginFetchInstallTasksApi: assert "tasks" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchInstallTasksApi() method = unwrap(api.get) @@ -632,7 +633,7 @@ class TestPluginFetchInstallTasksApi: class TestPluginFetchInstallTaskApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchInstallTaskApi() method = unwrap(api.get) @@ -645,7 +646,7 @@ class TestPluginFetchInstallTaskApi: assert "task" in result - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchInstallTaskApi() method = unwrap(api.get) @@ -662,7 +663,7 @@ class TestPluginFetchInstallTaskApi: class TestPluginDeleteInstallTaskApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteInstallTaskApi() method = unwrap(api.post) @@ -675,7 +676,7 @@ class TestPluginDeleteInstallTaskApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteInstallTaskApi() method = unwrap(api.post) @@ -692,7 +693,7 @@ class TestPluginDeleteInstallTaskApi: class TestPluginDeleteAllInstallTaskItemsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteAllInstallTaskItemsApi() method = unwrap(api.post) @@ -707,7 +708,7 @@ class TestPluginDeleteAllInstallTaskItemsApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteAllInstallTaskItemsApi() method = unwrap(api.post) @@ -724,7 +725,7 @@ class TestPluginDeleteAllInstallTaskItemsApi: class TestPluginDeleteInstallTaskItemApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginDeleteInstallTaskItemApi() method = unwrap(api.post) @@ -737,7 +738,7 @@ class TestPluginDeleteInstallTaskItemApi: assert result["success"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginDeleteInstallTaskItemApi() method = unwrap(api.post) @@ -754,7 +755,7 @@ class TestPluginDeleteInstallTaskItemApi: class TestPluginUpgradeFromMarketplaceApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUpgradeFromMarketplaceApi() method = unwrap(api.post) @@ -775,7 +776,7 @@ class TestPluginUpgradeFromMarketplaceApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUpgradeFromMarketplaceApi() method = unwrap(api.post) @@ -797,7 +798,7 @@ class TestPluginUpgradeFromMarketplaceApi: class TestPluginUpgradeFromGithubApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginUpgradeFromGithubApi() method = unwrap(api.post) @@ -821,7 +822,7 @@ class TestPluginUpgradeFromGithubApi: assert result["ok"] is True - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginUpgradeFromGithubApi() method = unwrap(api.post) @@ -846,7 +847,7 @@ class TestPluginUpgradeFromGithubApi: class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) @@ -873,7 +874,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: assert result["options"] == [1] - def test_daemon_error(self, app): + def test_daemon_error(self, app: Flask): api = PluginFetchDynamicSelectOptionsWithCredentialsApi() method = unwrap(api.post) @@ -901,7 +902,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: class TestPluginChangePreferencesApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginChangePreferencesApi() method = unwrap(api.post) @@ -931,7 +932,7 @@ class TestPluginChangePreferencesApi: assert result["success"] is True - def test_permission_fail(self, app): + def test_permission_fail(self, app: Flask): api = PluginChangePreferencesApi() method = unwrap(api.post) @@ -962,7 +963,7 @@ class TestPluginChangePreferencesApi: class TestPluginFetchPreferencesApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginFetchPreferencesApi() method = unwrap(api.get) @@ -996,7 +997,7 @@ class TestPluginFetchPreferencesApi: class TestPluginAutoUpgradeExcludePluginApi: - def test_success(self, app): + def test_success(self, app: Flask): api = PluginAutoUpgradeExcludePluginApi() method = unwrap(api.post) @@ -1011,7 +1012,7 @@ class TestPluginAutoUpgradeExcludePluginApi: assert result["success"] is True - def test_fail(self, app): + def test_fail(self, app: Flask): api = PluginAutoUpgradeExcludePluginApi() method = unwrap(api.post) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index e82a29f045..a52518c2d2 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -2,6 +2,7 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Unauthorized @@ -37,7 +38,7 @@ def unwrap(func): class TestTenantListApi: - def test_get_success_saas_path(self, app): + def test_get_success_saas_path(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -85,7 +86,7 @@ class TestTenantListApi: get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) get_features_mock.assert_not_called() - def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app): + def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app: Flask): """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used. billing.enabled is mocked False to prove the endpoint does not gate on it for this path @@ -140,7 +141,7 @@ class TestTenantListApi: get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) get_features_mock.assert_called_once_with("t2") - def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app): + def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask): """Test fallback to FeatureService when bulk billing returns empty result. BillingService.get_plan_bulk catches exceptions internally and returns empty dict, @@ -197,7 +198,7 @@ class TestTenantListApi: assert get_features_mock.call_count == 2 logger_warning_mock.assert_called_once() - def test_get_billing_disabled_community_path(self, app): + def test_get_billing_disabled_community_path(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -236,7 +237,7 @@ class TestTenantListApi: assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX get_features_mock.assert_called_once_with("t1") - def test_get_enterprise_only_skips_feature_service(self, app): + def test_get_enterprise_only_skips_feature_service(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -276,7 +277,7 @@ class TestTenantListApi: assert result["workspaces"][1]["current"] is True get_features_mock.assert_not_called() - def test_get_enterprise_only_with_empty_tenants(self, app): + def test_get_enterprise_only_with_empty_tenants(self, app: Flask): api = TenantListApi() method = unwrap(api.get) @@ -302,7 +303,7 @@ class TestTenantListApi: class TestWorkspaceListApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = WorkspaceListApi() method = unwrap(api.get) @@ -324,7 +325,7 @@ class TestWorkspaceListApi: assert result["total"] == 1 assert result["has_more"] is False - def test_get_has_next_true(self, app): + def test_get_has_next_true(self, app: Flask): api = WorkspaceListApi() method = unwrap(api.get) @@ -355,7 +356,7 @@ class TestWorkspaceListApi: class TestTenantApi: - def test_post_active_tenant(self, app): + def test_post_active_tenant(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -375,7 +376,7 @@ class TestTenantApi: assert status == 200 assert result["id"] == "t1" - def test_post_archived_with_switch(self, app): + def test_post_archived_with_switch(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -397,7 +398,7 @@ class TestTenantApi: assert result["id"] == "new" - def test_post_archived_no_tenant(self, app): + def test_post_archived_no_tenant(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -411,7 +412,7 @@ class TestTenantApi: with pytest.raises(Unauthorized): method(api) - def test_post_info_path(self, app): + def test_post_info_path(self, app: Flask): api = TenantApi() method = unwrap(api.post) @@ -454,7 +455,7 @@ class TestTenantInfoResponse: class TestSwitchWorkspaceApi: - def test_switch_success(self, app): + def test_switch_success(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -477,7 +478,7 @@ class TestSwitchWorkspaceApi: assert result["result"] == "success" - def test_switch_not_linked(self, app): + def test_switch_not_linked(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -493,7 +494,7 @@ class TestSwitchWorkspaceApi: with pytest.raises(AccountNotLinkTenantError): method(api) - def test_switch_tenant_not_found(self, app): + def test_switch_tenant_not_found(self, app: Flask): api = SwitchWorkspaceApi() method = unwrap(api.post) @@ -515,7 +516,7 @@ class TestSwitchWorkspaceApi: class TestCustomConfigWorkspaceApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CustomConfigWorkspaceApi() method = unwrap(api.post) @@ -538,7 +539,7 @@ class TestCustomConfigWorkspaceApi: assert result["result"] == "success" - def test_logo_fallback(self, app): + def test_logo_fallback(self, app: Flask): api = CustomConfigWorkspaceApi() method = unwrap(api.post) @@ -569,7 +570,7 @@ class TestCustomConfigWorkspaceApi: class TestWebappLogoWorkspaceApi: - def test_no_file(self, app): + def test_no_file(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -582,7 +583,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(NoFileUploadedError): method(api) - def test_too_many_files(self, app): + def test_too_many_files(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -601,7 +602,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(TooManyFilesError): method(api) - def test_invalid_extension(self, app): + def test_invalid_extension(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -616,7 +617,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(UnsupportedFileTypeError): method(api) - def test_upload_success(self, app): + def test_upload_success(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -648,7 +649,7 @@ class TestWebappLogoWorkspaceApi: assert status == 201 assert result["id"] == "file1" - def test_filename_missing(self, app): + def test_filename_missing(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -672,7 +673,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(FilenameNotExistsError): method(api) - def test_file_too_large(self, app): + def test_file_too_large(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -701,7 +702,7 @@ class TestWebappLogoWorkspaceApi: with pytest.raises(FileTooLargeError): method(api) - def test_service_unsupported_file(self, app): + def test_service_unsupported_file(self, app: Flask): api = WebappLogoWorkspaceApi() method = unwrap(api.post) @@ -732,7 +733,7 @@ class TestWebappLogoWorkspaceApi: class TestWorkspaceInfoApi: - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = WorkspaceInfoApi() method = unwrap(api.post) @@ -756,7 +757,7 @@ class TestWorkspaceInfoApi: assert result["result"] == "success" - def test_no_current_tenant(self, app): + def test_no_current_tenant(self, app: Flask): api = WorkspaceInfoApi() method = unwrap(api.post) @@ -774,7 +775,7 @@ class TestWorkspaceInfoApi: class TestWorkspacePermissionApi: - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = WorkspacePermissionApi() method = unwrap(api.get) @@ -799,7 +800,7 @@ class TestWorkspacePermissionApi: assert status == 200 assert result["workspace_id"] == "t1" - def test_no_current_tenant(self, app): + def test_no_current_tenant(self, app: Flask): api = WorkspacePermissionApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f5d93b5ac3..ae0edcf382 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -41,7 +41,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_for_chat_app( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving parameters for a chat app.""" # Arrange @@ -91,7 +91,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_for_workflow_app( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving parameters for a workflow app.""" # Arrange @@ -136,7 +136,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_raises_error_when_chat_config_missing( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test that AppUnavailableError is raised when chat app has no config.""" # Arrange @@ -174,7 +174,7 @@ class TestAppParameterApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_parameters_raises_error_when_workflow_missing( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test that AppUnavailableError is raised when workflow app has no workflow.""" # Arrange @@ -234,7 +234,14 @@ class TestAppMetaApi: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.app.app.AppService") def test_get_app_meta( - self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, + mock_app_service, + mock_db, + mock_validate_token, + mock_current_app, + mock_user_logged_in, + app: Flask, + mock_app_model, ): """Test retrieving app metadata via AppService.""" # Arrange @@ -310,7 +317,7 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_app_info( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model ): """Test retrieving basic app information.""" mock_current_app.login_manager = Mock() @@ -402,7 +409,9 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.current_app") @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") - def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app): + def test_get_app_info_with_no_tags( + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask + ): """Test retrieving app info when app has no tags.""" # Arrange mock_current_app.login_manager = Mock() @@ -453,7 +462,7 @@ class TestAppInfoApi: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.db") def test_get_app_info_returns_correct_mode( - self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode + self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, app_mode ): """Test that all app modes are correctly returned.""" # Arrange diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index c16ebad739..4741481ef6 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,6 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -190,7 +191,7 @@ class TestAudioServiceMockedBehavior: class TestAudioApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) api = AudioApi() handler = _unwrap(api.post) @@ -216,7 +217,7 @@ class TestAudioApi: (InvokeError("invoke"), CompletionRequestError), ], ) - def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: + def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) api = AudioApi() handler = _unwrap(api.post) @@ -227,7 +228,7 @@ class TestAudioApi: with pytest.raises(expected): handler(api, app_model=app_model, end_user=end_user) - def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_unhandled_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")) ) @@ -242,7 +243,7 @@ class TestAudioApi: class TestTextApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) api = TextApi() @@ -259,7 +260,7 @@ class TestTextApi: assert response == {"audio": "ok"} - def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()) ) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 3364c07e62..259741937f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -295,7 +296,7 @@ class TestCompletionControllerLogic: @patch("controllers.service_api.app.completion.service_api_ns") @patch("controllers.service_api.app.completion.AppGenerateService") - def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app): + def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask): """Test CompletionApi.post success path.""" from controllers.service_api.app.completion import CompletionApi @@ -320,7 +321,7 @@ class TestCompletionControllerLogic: mock_generate_service.generate.assert_called_once() @patch("controllers.service_api.app.completion.service_api_ns") - def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app): + def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask): """Test CompletionApi.post with wrong app mode.""" from controllers.service_api.app.completion import CompletionApi @@ -334,7 +335,7 @@ class TestCompletionControllerLogic: @patch("controllers.service_api.app.completion.service_api_ns") @patch("controllers.service_api.app.completion.AppGenerateService") - def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app): + def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask): """Test ChatApi.post success path.""" from controllers.service_api.app.completion import ChatApi @@ -355,7 +356,7 @@ class TestCompletionControllerLogic: assert response == {"text": "compacted"} @patch("controllers.service_api.app.completion.service_api_ns") - def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app): + def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask): """Test ChatApi.post with wrong app mode.""" from controllers.service_api.app.completion import ChatApi @@ -368,7 +369,7 @@ class TestCompletionControllerLogic: ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user) @patch("controllers.service_api.app.completion.AppTaskService") - def test_completion_stop_api_success(self, mock_task_service, app): + def test_completion_stop_api_success(self, mock_task_service, app: Flask): """Test CompletionStopApi.post success.""" from controllers.service_api.app.completion import CompletionStopApi @@ -385,7 +386,7 @@ class TestCompletionControllerLogic: mock_task_service.stop_task.assert_called_once() @patch("controllers.service_api.app.completion.AppTaskService") - def test_chat_stop_api_success(self, mock_task_service, app): + def test_chat_stop_api_success(self, mock_task_service, app: Flask): """Test ChatStopApi.post success.""" from controllers.service_api.app.completion import ChatStopApi diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index 4fb8ecf784..6dc8f54d42 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -20,6 +20,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, NotFound import services @@ -504,7 +505,7 @@ class TestConversationApiController: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user) - def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_list_last_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: class _BeginStub: def __enter__(self): return SimpleNamespace() @@ -552,7 +553,7 @@ class TestConversationDetailApiController: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") - def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "delete", @@ -570,7 +571,7 @@ class TestConversationDetailApiController: class TestConversationRenameApiController: - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "rename", @@ -602,7 +603,7 @@ class TestConversationVariablesApiController: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "get_conversational_variable", @@ -621,7 +622,7 @@ class TestConversationVariablesApiController: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001") - def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) monkeypatch.setattr( ConversationService, @@ -661,7 +662,7 @@ class TestConversationVariablesApiController: class TestConversationVariableDetailApiController: - def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_update_type_mismatch(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "update_conversation_variable", @@ -687,7 +688,7 @@ class TestConversationVariableDetailApiController: variable_id="00000000-0000-0000-0000-000000000002", ) - def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_update_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( ConversationService, "update_conversation_variable", @@ -713,7 +714,7 @@ class TestConversationVariableDetailApiController: variable_id="00000000-0000-0000-0000-000000000002", ) - def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_update_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC) monkeypatch.setattr( ConversationService, diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file.py b/api/tests/unit_tests/controllers/service_api/app/test_file.py index 7060bd79df..2615c3edac 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file.py @@ -16,6 +16,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from controllers.common.errors import ( FilenameNotExistsError, @@ -282,7 +283,7 @@ class TestFileApiPost: assert status == 201 mock_file_svc_cls.return_value.upload_file.assert_called_once() - def test_upload_no_file(self, app, mock_app_model, mock_end_user): + def test_upload_no_file(self, app: Flask, mock_app_model, mock_end_user): """Test NoFileUploadedError when no file in request.""" from controllers.service_api.app.file import FileApi @@ -296,7 +297,7 @@ class TestFileApiPost: with pytest.raises(NoFileUploadedError): _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) - def test_upload_too_many_files(self, app, mock_app_model, mock_end_user): + def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user): """Test TooManyFilesError when multiple files uploaded.""" from io import BytesIO @@ -317,7 +318,7 @@ class TestFileApiPost: with pytest.raises(TooManyFilesError): _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) - def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user): + def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user): """Test UnsupportedFileTypeError when file has no mimetype.""" from io import BytesIO diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py index 846d5368f3..510d4a9470 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py @@ -11,6 +11,7 @@ from types import SimpleNamespace from unittest.mock import ANY, MagicMock, Mock import pytest +from flask import Flask import services.app_generate_service as ags_module from controllers.service_api.app.workflow_events import WorkflowEventsApi @@ -281,7 +282,7 @@ class TestHitlServiceApi: workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open( - self, app, monkeypatch: pytest.MonkeyPatch + self, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: workflow_run = SimpleNamespace( id="run-1", diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index c2b8aed1ae..2bc9771862 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, InternalServerError, NotFound from controllers.service_api.app.error import NotChatAppError @@ -390,7 +391,7 @@ class TestMessageListApi: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user) - def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_conversation_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "pagination_by_first_id", @@ -409,7 +410,7 @@ class TestMessageListApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user) - def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_first_message_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "pagination_by_first_id", @@ -430,7 +431,7 @@ class TestMessageListApi: class TestMessageFeedbackApi: - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "create_feedback", @@ -452,7 +453,7 @@ class TestMessageFeedbackApi: class TestAppGetFeedbacksApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"]) api = AppGetFeedbacksApi() @@ -476,7 +477,7 @@ class TestMessageSuggestedApi: with pytest.raises(NotChatAppError): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", @@ -492,7 +493,7 @@ class TestMessageSuggestedApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_disabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", @@ -508,7 +509,7 @@ class TestMessageSuggestedApi: with pytest.raises(BadRequest): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_internal_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", @@ -524,7 +525,7 @@ class TestMessageSuggestedApi: with pytest.raises(InternalServerError): handler(api, app_model=app_model, end_user=end_user, message_id="m1") - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( MessageService, "get_suggested_questions_after_answer", diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index da09ec13ce..7115ea1e12 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -20,6 +20,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -366,7 +367,7 @@ class TestWorkflowRunRepository: class TestWorkflowRunDetailApi: - def test_not_workflow_app(self, app) -> None: + def test_not_workflow_app(self, app: Flask) -> None: api = WorkflowRunDetailApi() handler = _unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) @@ -397,7 +398,7 @@ class TestWorkflowRunDetailApi: class TestWorkflowRunApi: - def test_not_workflow_app(self, app) -> None: + def test_not_workflow_app(self, app: Flask) -> None: api = WorkflowRunApi() handler = _unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) @@ -407,7 +408,7 @@ class TestWorkflowRunApi: with pytest.raises(NotWorkflowAppError): handler(api, app_model=app_model, end_user=end_user) - def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_rate_limit(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AppGenerateService, "generate", @@ -425,7 +426,7 @@ class TestWorkflowRunApi: class TestWorkflowRunByIdApi: - def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AppGenerateService, "generate", @@ -441,7 +442,7 @@ class TestWorkflowRunByIdApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, workflow_id="w1") - def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_draft_workflow(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AppGenerateService, "generate", @@ -459,7 +460,7 @@ class TestWorkflowRunByIdApi: class TestWorkflowTaskStopApi: - def test_wrong_mode(self, app) -> None: + def test_wrong_mode(self, app: Flask) -> None: api = WorkflowTaskStopApi() handler = _unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) @@ -469,7 +470,7 @@ class TestWorkflowTaskStopApi: with pytest.raises(NotWorkflowAppError): handler(api, app_model=app_model, end_user=end_user, task_id="t1") - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: stop_mock = Mock() send_mock = Mock() monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock) @@ -489,7 +490,7 @@ class TestWorkflowTaskStopApi: class TestWorkflowAppLogApi: - def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: class _BeginStub: def __enter__(self): return SimpleNamespace() @@ -557,7 +558,7 @@ class TestWorkflowRunDetailApiGet: self, mock_db, mock_repo_factory, - app, + app: Flask, mock_workflow_app, ): """Test successful workflow run detail retrieval.""" @@ -579,7 +580,7 @@ class TestWorkflowRunDetailApiGet: assert result["status"] == "succeeded" @patch("controllers.service_api.app.workflow.db") - def test_get_workflow_run_wrong_app_mode(self, mock_db, app): + def test_get_workflow_run_wrong_app_mode(self, mock_db, app: Flask): """Test NotWorkflowAppError when app mode is not workflow or advanced_chat.""" from controllers.service_api.app.workflow import WorkflowRunDetailApi @@ -604,7 +605,7 @@ class TestWorkflowTaskStopApiPost: self, mock_queue_mgr, mock_graph_mgr, - app, + app: Flask, mock_workflow_app, ): """Test successful workflow task stop.""" @@ -624,7 +625,7 @@ class TestWorkflowTaskStopApiPost: mock_graph_mgr.assert_called_once() mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1") - def test_stop_workflow_task_wrong_app_mode(self, app): + def test_stop_workflow_task_wrong_app_mode(self, app: Flask): """Test NotWorkflowAppError when app mode is not workflow.""" from controllers.service_api.app.workflow import WorkflowTaskStopApi @@ -649,7 +650,7 @@ class TestWorkflowAppLogApiGet: self, mock_db, mock_wf_svc_cls, - app, + app: Flask, mock_workflow_app, ): """Test successful workflow log retrieval.""" diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py index f45a7f9632..b3edc2ecd8 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -41,7 +42,7 @@ class TestWorkflowEventsApi: with pytest.raises(NotWorkflowAppError): handler(api, app_model=app_model, end_user=end_user, task_id="run-1") - def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: _mock_repo_for_run(monkeypatch, workflow_run=None) api = WorkflowEventsApi() handler = _unwrap(api.get) @@ -52,7 +53,7 @@ class TestWorkflowEventsApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, task_id="run-1") - def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_workflow_run_permission_denied(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", @@ -70,7 +71,7 @@ class TestWorkflowEventsApi: with pytest.raises(NotFound): handler(api, app_model=app_model, end_user=end_user, task_id="run-1") - def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_finished_run_returns_sse(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", @@ -103,7 +104,7 @@ class TestWorkflowEventsApi: assert payload["task_id"] == "run-1" assert payload["event"] == "workflow_finished" - def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_running_run_streams_events(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", @@ -135,7 +136,7 @@ class TestWorkflowEventsApi: ) workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"]) - def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_running_run_with_snapshot(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow_run = SimpleNamespace( id="run-1", app_id="app-1", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py index f33c482d04..362af883ed 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/rag_pipeline/test_rag_pipeline_workflow.py @@ -23,6 +23,7 @@ from datetime import UTC, datetime from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden, NotFound @@ -373,7 +374,7 @@ class TestDatasourcePluginsApiGet: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") - def test_get_plugins_success(self, mock_svc_cls, mock_db, app): + def test_get_plugins_success(self, mock_svc_cls, mock_db, app: Flask): """Test successful retrieval of datasource plugins.""" tenant_id = str(uuid.uuid4()) dataset_id = str(uuid.uuid4()) @@ -396,7 +397,7 @@ class TestDatasourcePluginsApiGet: ) @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_get_plugins_not_found(self, mock_db, app): + def test_get_plugins_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" mock_db.session.scalar.return_value = None @@ -407,7 +408,7 @@ class TestDatasourcePluginsApiGet: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") - def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app): + def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app: Flask): """Test empty plugin list.""" mock_db.session.scalar.return_value = Mock() mock_svc_instance = Mock() @@ -439,7 +440,7 @@ class TestDatasourceNodeRunApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app): + def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app: Flask): """Test successful datasource node run.""" tenant_id = str(uuid.uuid4()) dataset_id = str(uuid.uuid4()) @@ -473,7 +474,7 @@ class TestDatasourceNodeRunApiPost: mock_svc_instance.run_datasource_workflow_node.assert_called_once() @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_post_not_found(self, mock_db, app): + def test_post_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" mock_db.session.scalar.return_value = None @@ -488,7 +489,7 @@ class TestDatasourceNodeRunApiPost: ) @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app): + def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app: Flask): """Test AssertionError when current_user is not an Account instance.""" mock_db.session.scalar.return_value = Mock() mock_ns.payload = { @@ -549,7 +550,7 @@ class TestPipelineRunApiPost: mock_gen_svc.generate.assert_called_once() @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_post_not_found(self, mock_db, app): + def test_post_not_found(self, mock_db, app: Flask): """Test NotFound when dataset check fails.""" mock_db.session.scalar.return_value = None @@ -561,7 +562,7 @@ class TestPipelineRunApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns") - def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app): + def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app: Flask): """Test Forbidden when current_user is not an Account.""" mock_db.session.scalar.return_value = Mock() mock_ns.payload = { @@ -585,7 +586,7 @@ class TestFileUploadApiPost: @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user") @patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db") - def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app): + def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app: Flask): """Test successful file upload.""" mock_current_user.__bool__ = Mock(return_value=True) @@ -621,7 +622,7 @@ class TestFileUploadApiPost: assert response["name"] == "doc.pdf" assert response["extension"] == "pdf" - def test_upload_no_file(self, app): + def test_upload_no_file(self, app: Flask): """Test error when no file is uploaded.""" with app.test_request_context( "/datasets/pipeline/file-upload", diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index e9c3e6d376..fe8fc02548 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -18,6 +18,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.dataset.segment import ( @@ -782,7 +783,7 @@ class TestSegmentApiGet: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -893,7 +894,7 @@ class TestSegmentApiPost: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -946,7 +947,7 @@ class TestSegmentApiPost: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -989,7 +990,7 @@ class TestSegmentApiPost: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1041,7 +1042,7 @@ class TestDatasetSegmentApiDelete: mock_doc_svc, mock_dataset_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1086,7 +1087,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_doc_svc, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1282,7 +1283,7 @@ class TestDatasetSegmentApiUpdate: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1322,7 +1323,7 @@ class TestDatasetSegmentApiUpdate: mock_dataset_svc, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1374,7 +1375,7 @@ class TestDatasetSegmentApiGetSingle: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1421,7 +1422,7 @@ class TestDatasetSegmentApiGetSingle: mock_seg_svc, mock_marshal, mock_summary_svc, - app, + app: Flask, mock_tenant, mock_dataset, mock_segment, @@ -1460,7 +1461,7 @@ class TestDatasetSegmentApiGetSingle: self, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1491,7 +1492,7 @@ class TestDatasetSegmentApiGetSingle: mock_account_fn, mock_dataset_svc, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1526,7 +1527,7 @@ class TestDatasetSegmentApiGetSingle: mock_dataset_svc, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1570,7 +1571,7 @@ class TestChildChunkApiGet: mock_doc_svc, mock_seg_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1609,7 +1610,7 @@ class TestChildChunkApiGet: self, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1638,7 +1639,7 @@ class TestChildChunkApiGet: mock_db, mock_account_fn, mock_doc_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1670,7 +1671,7 @@ class TestChildChunkApiGet: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1729,7 +1730,7 @@ class TestChildChunkApiPost: mock_doc_svc, mock_seg_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1771,7 +1772,7 @@ class TestChildChunkApiPost: mock_feature_svc, mock_db, mock_account_fn, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1809,7 +1810,7 @@ class TestChildChunkApiPost: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1863,7 +1864,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1913,7 +1914,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1954,7 +1955,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -1994,7 +1995,7 @@ class TestDatasetChildChunkApiDelete: mock_account_fn, mock_doc_svc, mock_seg_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index b93a1cf14b..b7e24f9201 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -19,6 +19,7 @@ import uuid from unittest.mock import Mock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.service_api.dataset.metadata import ( @@ -76,7 +77,7 @@ class TestDatasetMetadataCreatePost: mock_dataset_svc, mock_meta_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -106,7 +107,7 @@ class TestDatasetMetadataCreatePost: def test_create_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -136,7 +137,7 @@ class TestDatasetMetadataCreateGet: self, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -160,7 +161,7 @@ class TestDatasetMetadataCreateGet: def test_get_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -201,7 +202,7 @@ class TestDatasetMetadataServiceApiPatch: mock_dataset_svc, mock_meta_svc, mock_marshal, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -232,7 +233,7 @@ class TestDatasetMetadataServiceApiPatch: def test_update_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -273,7 +274,7 @@ class TestDatasetMetadataServiceApiDelete: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -302,7 +303,7 @@ class TestDatasetMetadataServiceApiDelete: def test_delete_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -336,7 +337,7 @@ class TestDatasetMetadataBuiltInFieldGet: def test_get_built_in_fields_success( self, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -382,7 +383,7 @@ class TestDatasetMetadataBuiltInFieldAction: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -414,7 +415,7 @@ class TestDatasetMetadataBuiltInFieldAction: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -441,7 +442,7 @@ class TestDatasetMetadataBuiltInFieldAction: def test_action_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -485,7 +486,7 @@ class TestDocumentMetadataEditPost: mock_current_user, mock_dataset_svc, mock_meta_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): @@ -513,7 +514,7 @@ class TestDocumentMetadataEditPost: def test_update_documents_metadata_dataset_not_found( self, mock_dataset_svc, - app, + app: Flask, mock_tenant, mock_dataset, ): diff --git a/api/tests/unit_tests/controllers/service_api/test_index.py b/api/tests/unit_tests/controllers/service_api/test_index.py index c560a3c698..8441118181 100644 --- a/api/tests/unit_tests/controllers/service_api/test_index.py +++ b/api/tests/unit_tests/controllers/service_api/test_index.py @@ -5,6 +5,7 @@ Unit tests for Service API Index endpoint from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.service_api.index import IndexApi @@ -13,7 +14,7 @@ class TestIndexApi: """Test suite for IndexApi resource.""" @patch("controllers.service_api.index.dify_config", autospec=True) - def test_get_returns_api_info(self, mock_config, app): + def test_get_returns_api_info(self, mock_config, app: Flask): """Test that GET returns API metadata with correct structure.""" # Arrange mock_config.project.version = "1.0.0-test" @@ -32,7 +33,7 @@ class TestIndexApi: assert response["api_version"] == "v1" assert response["server_version"] == "1.0.0-test" - def test_get_response_has_required_fields(self, app): + def test_get_response_has_required_fields(self, app: Flask): """Test that response contains all required fields.""" # Arrange mock_config = MagicMock() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 6dfbdcf98e..30d7b92913 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -39,7 +39,7 @@ class TestValidateAndGetApiToken: app.config["TESTING"] = True return app - def test_missing_authorization_header(self, app): + def test_missing_authorization_header(self, app: Flask): """Test that Unauthorized is raised when Authorization header is missing.""" # Arrange with app.test_request_context("/", method="GET"): @@ -50,7 +50,7 @@ class TestValidateAndGetApiToken: validate_and_get_api_token("app") assert "Authorization header must be provided" in str(exc_info.value) - def test_invalid_auth_scheme(self, app): + def test_invalid_auth_scheme(self, app: Flask): """Test that Unauthorized is raised when auth scheme is not Bearer.""" # Arrange with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}): @@ -62,7 +62,7 @@ class TestValidateAndGetApiToken: @patch("controllers.service_api.wraps.record_token_usage") @patch("controllers.service_api.wraps.ApiTokenCache") @patch("controllers.service_api.wraps.fetch_token_with_single_flight") - def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app): + def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask): """Test that valid token returns the ApiToken object.""" # Arrange mock_api_token = Mock(spec=ApiToken) @@ -84,7 +84,7 @@ class TestValidateAndGetApiToken: @patch("controllers.service_api.wraps.record_token_usage") @patch("controllers.service_api.wraps.ApiTokenCache") @patch("controllers.service_api.wraps.fetch_token_with_single_flight") - def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app): + def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask): """Test that invalid token raises Unauthorized.""" # Arrange from werkzeug.exceptions import Unauthorized @@ -161,7 +161,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app no longer exists.""" # Arrange mock_api_token = Mock() @@ -182,7 +182,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app status is abnormal.""" # Arrange mock_api_token = Mock() @@ -205,7 +205,7 @@ class TestValidateAppToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app): + def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app: Flask): """Test that Forbidden is raised when app API is disabled.""" # Arrange mock_api_token = Mock() @@ -240,7 +240,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app): + def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app: Flask): """Test that request is allowed when under resource limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -264,7 +264,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app): + def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app: Flask): """Test that Forbidden is raised when at resource limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -287,7 +287,7 @@ class TestCloudEditionBillingResourceCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app): + def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app: Flask): """Test that request is allowed when billing is disabled.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -320,7 +320,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app): + def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask): """Test that add_segment is rejected in SANDBOX plan.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -342,7 +342,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_features") - def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app): + def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask): """Test that non-add_segment operations are allowed in SANDBOX.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -376,7 +376,7 @@ class TestCloudEditionBillingRateLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit") - def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app): + def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app: Flask): """Test that request is allowed when within rate limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -406,7 +406,7 @@ class TestCloudEditionBillingRateLimitCheck: @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit") @patch("controllers.service_api.wraps.db") - def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app): + def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app: Flask): """Test that Forbidden is raised when over rate limit.""" # Arrange mock_validate_token.return_value = Mock(tenant_id="tenant123") @@ -445,7 +445,7 @@ class TestValidateDatasetToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") @patch("controllers.service_api.wraps.current_app") - def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app): + def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask): """Test that valid dataset token allows access.""" # Arrange # Use standard Mock for login_manager @@ -487,7 +487,7 @@ class TestValidateDatasetToken: @patch("controllers.service_api.wraps.db") @patch("controllers.service_api.wraps.validate_and_get_api_token") - def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app): + def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app: Flask): """Test that NotFound is raised when dataset doesn't exist.""" # Arrange mock_api_token = Mock() diff --git a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py index 86b461cf04..c1c1291281 100644 --- a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py +++ b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py @@ -13,7 +13,7 @@ class TestExternalDataFetch: app = Flask(__name__) return app - def test_fetch_success(self, app): + def test_fetch_success(self, app: Flask): with app.app_context(): fetcher = ExternalDataFetch() @@ -79,7 +79,7 @@ class TestExternalDataFetch: assert result_inputs == inputs assert result_inputs is not inputs # Should be a copy - def test_fetch_with_none_variable(self, app): + def test_fetch_with_none_variable(self, app: Flask): with app.app_context(): fetcher = ExternalDataFetch() tool = ExternalDataVariableEntity(variable="var1", type="type1", config={}) @@ -95,7 +95,7 @@ class TestExternalDataFetch: assert "var1" not in result_inputs assert result_inputs == {"in": "val"} - def test_query_external_data_tool(self, app): + def test_query_external_data_tool(self, app: Flask): fetcher = ExternalDataFetch() tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"}) diff --git a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py b/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py deleted file mode 100644 index 53a9e6210c..0000000000 --- a/api/tests/unit_tests/services/plugin/test_plugin_permission_service.py +++ /dev/null @@ -1,79 +0,0 @@ -from unittest.mock import MagicMock, patch - -from models.account import TenantPluginPermission - -MODULE = "services.plugin.plugin_permission_service" - - -def _patched_session(): - """Patch session_factory.create_session() to return a mock session as context manager.""" - session = MagicMock() - session.__enter__ = MagicMock(return_value=session) - session.__exit__ = MagicMock(return_value=False) - session.begin.return_value.__enter__ = MagicMock(return_value=session) - session.begin.return_value.__exit__ = MagicMock(return_value=False) - mock_factory = MagicMock() - mock_factory.create_session.return_value = session - patcher = patch(f"{MODULE}.session_factory", mock_factory) - return patcher, session - - -class TestGetPermission: - def test_returns_permission_when_found(self): - p1, session = _patched_session() - permission = MagicMock() - session.scalar.return_value = permission - - with p1: - from services.plugin.plugin_permission_service import PluginPermissionService - - result = PluginPermissionService.get_permission("t1") - - assert result is permission - - def test_returns_none_when_not_found(self): - p1, session = _patched_session() - session.scalar.return_value = None - - with p1: - from services.plugin.plugin_permission_service import PluginPermissionService - - result = PluginPermissionService.get_permission("t1") - - assert result is None - - -class TestChangePermission: - def test_creates_new_permission_when_not_exists(self): - p1, session = _patched_session() - session.scalar.return_value = None - - with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls: - perm_cls.return_value = MagicMock() - from services.plugin.plugin_permission_service import PluginPermissionService - - result = PluginPermissionService.change_permission( - "t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE - ) - - assert result is True - session.begin.assert_called_once() - session.add.assert_called_once() - - def test_updates_existing_permission(self): - p1, session = _patched_session() - existing = MagicMock() - session.scalar.return_value = existing - - with p1: - from services.plugin.plugin_permission_service import PluginPermissionService - - result = PluginPermissionService.change_permission( - "t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS - ) - - assert result is True - session.begin.assert_called_once() - assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS - assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS - session.add.assert_not_called() diff --git a/docker/.env.default b/docker/.env.default new file mode 100644 index 0000000000..6f6683b9f5 --- /dev/null +++ b/docker/.env.default @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------ +# Minimal defaults for Docker Compose deployments. +# +# Keep local changes in .env. Use .env.example as the full reference +# for advanced and service-specific settings. +# ------------------------------------------------------------------ + +# Public URLs used when Dify generates links. Change these together when +# exposing Dify under another hostname, IP address, or port. +CONSOLE_WEB_URL=http://localhost +SERVICE_API_URL=http://localhost +APP_WEB_URL=http://localhost +FILES_URL=http://localhost +INTERNAL_FILES_URL=http://api:5001 +TRIGGER_URL=http://localhost +ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} +NEXT_PUBLIC_SOCKET_URL=ws://localhost +EXPOSE_PLUGIN_DEBUGGING_HOST=localhost +EXPOSE_PLUGIN_DEBUGGING_PORT=5003 + +# Built-in metadata database defaults. +DB_TYPE=postgresql +DB_USERNAME=postgres +DB_PASSWORD=difyai123456 +DB_HOST=db_postgres +DB_PORT=5432 +DB_DATABASE=dify + +# Built-in Redis defaults. +REDIS_HOST=redis +REDIS_PORT=6379 +REDIS_PASSWORD=difyai123456 + +# Default file storage. +STORAGE_TYPE=opendal +OPENDAL_SCHEME=fs +OPENDAL_FS_ROOT=storage + +# Default vector database. +VECTOR_STORE=weaviate + +# Internal service authentication. Paired values must match. +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DIFY_INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 + +# Host ports. +EXPOSE_NGINX_PORT=80 +EXPOSE_NGINX_SSL_PORT=443 + +# Docker Compose profiles for bundled services. +COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql} diff --git a/docker/.env.example b/docker/.env.example index 29741474fa..122228cdd1 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1003,7 +1003,7 @@ NOTION_INTERNAL_SECRET= # ------------------------------ # Mail type, support: resend, smtp, sendgrid -MAIL_TYPE=resend +MAIL_TYPE= # Default send from email address, if not specified # If using SendGrid, use the 'from' field for authentication if necessary. @@ -1011,7 +1011,7 @@ MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. RESEND_API_URL=https://api.resend.com -RESEND_API_KEY=your-resend-api-key +RESEND_API_KEY= # SMTP server configuration, used when MAIL_TYPE is `smtp` @@ -1359,10 +1359,10 @@ NGINX_ENABLE_CERTBOT_CHALLENGE=false # ------------------------------ # Email address (required to get certificates from Let's Encrypt) -CERTBOT_EMAIL=your_email@example.com +CERTBOT_EMAIL= # Domain name -CERTBOT_DOMAIN=your_domain.com +CERTBOT_DOMAIN= # certbot command options # i.e: --force-renewal --dry-run --test-cert --debug diff --git a/docker/README.md b/docker/README.md index 3130fa9886..3a7f4c2ad5 100644 --- a/docker/README.md +++ b/docker/README.md @@ -7,28 +7,28 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T - **Certbot Container**: `docker-compose.yaml` now contains `certbot` for managing SSL certificates. This container automatically renews certificates and ensures secure HTTPS connections.\ For more information, refer `docker/certbot/README.md`. -- **Persistent Environment Variables**: Environment variables are now managed through a `.env` file, ensuring that your configurations persist across deployments. +- **Persistent Environment Variables**: Default environment variables are managed through `.env.default`, while local overrides are stored in `.env`, ensuring that your configurations persist across deployments. > What is `.env`?

- > The `.env` file is a crucial component in Docker and Docker Compose environments, serving as a centralized configuration file where you can define environment variables that are accessible to the containers at runtime. This file simplifies the management of environment settings across different stages of development, testing, and production, providing consistency and ease of configuration to deployments. + > The `.env` file is a local override file. Keep it small by adding only the values that differ from `.env.default`. Use `.env.example` as the full reference when you need advanced configuration. - **Unified Vector Database Services**: All vector database services are now managed from a single Docker Compose file `docker-compose.yaml`. You can switch between different vector databases by setting the `VECTOR_STORE` environment variable in your `.env` file. -- **Mandatory .env File**: A `.env` file is now required to run `docker compose up`. This file is crucial for configuring your deployment and for any custom settings to persist through upgrades. +- **Local .env Overrides**: The `dify-compose` and `dify-compose.ps1` wrappers create `.env` if it is missing and generate a persistent `SECRET_KEY` for this deployment. ### How to Deploy Dify with `docker-compose.yaml` 1. **Prerequisites**: Ensure Docker and Docker Compose are installed on your system. 1. **Environment Setup**: - Navigate to the `docker` directory. - - Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`. - - Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options. - - **Optional (Recommended for upgrades)**: - You may use the environment synchronization tool to help keep your `.env` file aligned with the latest `.env.example` updates, while preserving your custom settings. - This is especially useful when upgrading Dify or managing a large, customized `.env` file. + - No copy step is required. The `dify-compose` wrappers create `.env` if it is missing and write a generated `SECRET_KEY` to it. + - When prompted on first run, press Enter to use the default deployment, or answer `y` to stop and edit `.env` first. + - Customize `.env` only when you need to override defaults from `.env.default`. Refer to `.env.example` for the full list of available variables. + - **Optional (for advanced deployments)**: + If you maintain a full `.env` file copied from `.env.example`, you may use the environment synchronization tool to keep it aligned with the latest `.env.example` updates while preserving your custom settings. See the [Environment Variables Synchronization](#environment-variables-synchronization) section below. 1. **Running the Services**: - - Execute `docker compose up` from the `docker` directory to start the services. + - Execute `./dify-compose up -d` from the `docker` directory to start the services. On Windows PowerShell, run `.\dify-compose.ps1 up -d`. - To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`. 1. **SSL Certificate Setup**: - Refer `docker/certbot/README.md` to set up SSL certificates using Certbot. @@ -58,7 +58,13 @@ For users migrating from the `docker-legacy` setup: 1. **Data Migration**: - Ensure that data from services like databases and caches is backed up and migrated appropriately to the new structure if necessary. -### Overview of `.env` +### Overview of `.env.default`, `.env`, and `.env.example` + +- `.env.default` contains the minimal default configuration for Docker Compose deployments. +- `.env` contains the generated `SECRET_KEY` plus any local overrides. +- `.env.example` is the full reference for advanced configuration. + +The `dify-compose` wrappers merge `.env.default` and `.env` into a temporary environment file, append paired internal service keys when needed, and remove the temporary file after Docker Compose starts. #### Key Modules and Customization @@ -118,9 +124,11 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w ### Environment Variables Synchronization -When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.example`. +When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.default` or `.env.example`. -To help keep your existing `.env` file up to date **without losing your custom values**, an optional environment variables synchronization tool is provided. +If you use the default override-only workflow, review `.env.default` and add only the values you need to override to `.env`. + +If you maintain a full `.env` file copied from `.env.example`, an optional environment variables synchronization tool is provided. > This tool performs a **one-way synchronization** from `.env.example` to `.env`. > Existing values in `.env` are never overwritten automatically. @@ -143,9 +151,9 @@ Before synchronization, the current `.env` file is saved to the `env-backup/` di **When to use** -- After upgrading Dify to a newer version +- After upgrading Dify to a newer version with a full `.env` file - When `.env.example` has been updated with new environment variables -- When managing a large or heavily customized `.env` file +- When managing a large or heavily customized `.env` file copied from `.env.example` **Usage** diff --git a/docker/dify-compose b/docker/dify-compose new file mode 100755 index 0000000000..16bbd6b538 --- /dev/null +++ b/docker/dify-compose @@ -0,0 +1,334 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +DEFAULT_ENV_FILE=".env.default" +USER_ENV_FILE=".env" + +log() { + printf '%s\n' "$*" >&2 +} + +die() { + printf 'Error: %s\n' "$*" >&2 + exit 1 +} + +detect_compose() { + if docker compose version >/dev/null 2>&1; then + COMPOSE_CMD=(docker compose) + return + fi + + if command -v docker-compose >/dev/null 2>&1; then + COMPOSE_CMD=(docker-compose) + return + fi + + die "Docker Compose is not available. Install Docker Compose, then run this command again." +} + +generate_secret_key() { + if command -v openssl >/dev/null 2>&1; then + openssl rand -base64 42 + return + fi + + if command -v dd >/dev/null 2>&1 && command -v base64 >/dev/null 2>&1; then + dd if=/dev/urandom bs=42 count=1 2>/dev/null | base64 | tr -d '\n' + printf '\n' + return + fi + + return 1 +} + +ensure_env_files() { + [[ -f "$DEFAULT_ENV_FILE" ]] || die "$DEFAULT_ENV_FILE is missing." + + if [[ -f "$USER_ENV_FILE" ]]; then + return + fi + + : >"$USER_ENV_FILE" + + if [[ ! -t 0 ]]; then + log "Created $USER_ENV_FILE for local overrides." + return + fi + + printf 'Created %s for local overrides.\n' "$USER_ENV_FILE" + printf 'Do you need a custom deployment now? (Most users can press Enter to skip.) [y/N] ' + read -r answer + + case "${answer:-}" in + y | Y | yes | YES | Yes) + cat <<'EOF' +Edit .env with the settings you want to override, using .env.example as the full reference. +Run ./dify-compose up -d again when you are ready. +EOF + exit 0 + ;; + esac +} + +user_env_value() { + local key="$1" + awk -F= -v target="$key" ' + /^[[:space:]]*#/ || !/=/{ next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + value = substr($0, index($0, "=") + 1) + gsub(/^[[:space:]]+|[[:space:]]+$/, "", value) + if ((value ~ /^".*"$/) || (value ~ /^'\''.*'\''$/)) { + value = substr(value, 2, length(value) - 2) + } + result = value + } + } + END { print result } + ' "$USER_ENV_FILE" +} + +set_user_env_value() { + local key="$1" + local value="$2" + local temp_file + + temp_file="$(mktemp "${TMPDIR:-/tmp}/dify-env.XXXXXX")" + awk -F= -v target="$key" -v replacement="$key=$value" ' + BEGIN { replaced = 0 } + /^[[:space:]]*#/ || !/=/{ print; next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + if (!replaced) { + print replacement + replaced = 1 + } + next + } + print + } + END { + if (!replaced) { + print replacement + } + } + ' "$USER_ENV_FILE" >"$temp_file" + mv "$temp_file" "$USER_ENV_FILE" +} + +ensure_secret_key() { + local current_secret_key + local secret_key + + current_secret_key="$(user_env_value SECRET_KEY)" + if [[ -n "$current_secret_key" ]]; then + return + fi + + secret_key="$(generate_secret_key)" || die "Unable to generate SECRET_KEY. Install openssl or configure SECRET_KEY in .env." + set_user_env_value SECRET_KEY "$secret_key" + log "Generated SECRET_KEY in $USER_ENV_FILE." +} + +env_value() { + local key="$1" + awk -F= -v target="$key" ' + /^[[:space:]]*#/ || !/=/{ next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + value = substr($0, index($0, "=") + 1) + gsub(/^[[:space:]]+|[[:space:]]+$/, "", value) + if ((value ~ /^".*"$/) || (value ~ /^'\''.*'\''$/)) { + value = substr(value, 2, length(value) - 2) + } + result = value + } + } + END { print result } + ' "$DEFAULT_ENV_FILE" "$USER_ENV_FILE" +} + +user_overrides() { + local key="$1" + grep -Eq "^[[:space:]]*${key}[[:space:]]*=" "$USER_ENV_FILE" +} + +write_merged_env() { + awk ' + function trim(s) { + sub(/^[[:space:]]+/, "", s) + sub(/[[:space:]]+$/, "", s) + return s + } + + /^[[:space:]]*#/ || !/=/{ next } + + { + key = $0 + sub(/=.*/, "", key) + key = trim(key) + if (key == "") { + next + } + + value = substr($0, index($0, "=") + 1) + value = trim(value) + + if (!(key in seen)) { + order[++count] = key + seen[key] = 1 + } + + values[key] = value + } + + END { + for (i = 1; i <= count; i++) { + key = order[i] + print key "=" values[key] + } + } + ' "$DEFAULT_ENV_FILE" "$USER_ENV_FILE" >"$MERGED_ENV_FILE" +} + +set_merged_env_value() { + local key="$1" + local value="$2" + local temp_file + + temp_file="$(mktemp "${TMPDIR:-/tmp}/dify-compose-env.XXXXXX")" + awk -F= -v target="$key" -v replacement="$key=$value" ' + BEGIN { replaced = 0 } + /^[[:space:]]*#/ || !/=/{ print; next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + if (!replaced) { + print replacement + replaced = 1 + } + next + } + print + } + END { + if (!replaced) { + print replacement + } + } + ' "$MERGED_ENV_FILE" >"$temp_file" + mv "$temp_file" "$MERGED_ENV_FILE" +} + +set_if_not_overridden() { + local key="$1" + local value="$2" + + if user_overrides "$key"; then + return + fi + + set_merged_env_value "$key" "$value" +} + +metadata_db_host() { + case "$1" in + mysql) printf 'db_mysql' ;; + postgresql | '') printf 'db_postgres' ;; + *) printf '%s' "$(env_value DB_HOST)" ;; + esac +} + +metadata_db_port() { + case "$1" in + mysql) printf '3306' ;; + postgresql | '') printf '5432' ;; + *) printf '%s' "$(env_value DB_PORT)" ;; + esac +} + +metadata_db_user() { + case "$1" in + mysql) printf 'root' ;; + postgresql | '') printf 'postgres' ;; + *) printf '%s' "$(env_value DB_USERNAME)" ;; + esac +} + +build_merged_env() { + MERGED_ENV_FILE="$(mktemp "${TMPDIR:-/tmp}/dify-compose.XXXXXX")" + trap 'rm -f "$MERGED_ENV_FILE"' EXIT + + write_merged_env + + local db_type + local redis_host + local redis_port + local redis_username + local redis_password + local redis_auth + local code_execution_api_key + local weaviate_api_key + + db_type="$(env_value DB_TYPE)" + + set_if_not_overridden DB_HOST "$(metadata_db_host "$db_type")" + set_if_not_overridden DB_PORT "$(metadata_db_port "$db_type")" + set_if_not_overridden DB_USERNAME "$(metadata_db_user "$db_type")" + + if ! user_overrides CELERY_BROKER_URL; then + redis_host="$(env_value REDIS_HOST)" + redis_port="$(env_value REDIS_PORT)" + redis_username="$(env_value REDIS_USERNAME)" + redis_password="$(env_value REDIS_PASSWORD)" + redis_auth="" + + if [[ -n "$redis_username" && -n "$redis_password" ]]; then + redis_auth="${redis_username}:${redis_password}@" + elif [[ -n "$redis_password" ]]; then + redis_auth=":${redis_password}@" + elif [[ -n "$redis_username" ]]; then + redis_auth="${redis_username}@" + fi + + set_merged_env_value CELERY_BROKER_URL "redis://${redis_auth}${redis_host:-redis}:${redis_port:-6379}/1" + fi + + if ! user_overrides SANDBOX_API_KEY; then + code_execution_api_key="$(env_value CODE_EXECUTION_API_KEY)" + set_if_not_overridden SANDBOX_API_KEY "${code_execution_api_key:-dify-sandbox}" + fi + + if ! user_overrides WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS; then + weaviate_api_key="$(env_value WEAVIATE_API_KEY)" + set_if_not_overridden WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS \ + "${weaviate_api_key:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}" + fi +} + +main() { + detect_compose + ensure_env_files + ensure_secret_key + build_merged_env + + if [[ "$#" -eq 0 ]]; then + set -- up -d + fi + + "${COMPOSE_CMD[@]}" --env-file "$MERGED_ENV_FILE" "$@" +} + +main "$@" diff --git a/docker/dify-compose.ps1 b/docker/dify-compose.ps1 new file mode 100644 index 0000000000..851f8b76fe --- /dev/null +++ b/docker/dify-compose.ps1 @@ -0,0 +1,317 @@ +$ErrorActionPreference = "Stop" +Set-StrictMode -Version Latest + +$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path +Set-Location $ScriptDir + +$DefaultEnvFile = ".env.default" +$UserEnvFile = ".env" +$MergedEnvFile = $null +$Utf8NoBom = New-Object System.Text.UTF8Encoding -ArgumentList $false + +function Write-Info { + param([string]$Message) + [Console]::Error.WriteLine($Message) +} + +function Fail { + param([string]$Message) + [Console]::Error.WriteLine("Error: $Message") + exit 1 +} + +function Test-CommandSuccess { + param([string[]]$Command) + + try { + $Executable = $Command[0] + $CommandArgs = @() + if ($Command.Length -gt 1) { + $CommandArgs = @($Command[1..($Command.Length - 1)]) + } + + & $Executable @CommandArgs *> $null + return $LASTEXITCODE -eq 0 + } + catch { + return $false + } +} + +function Get-ComposeCommand { + if (Test-CommandSuccess @("docker", "compose", "version")) { + return @("docker", "compose") + } + + if ((Get-Command "docker-compose" -ErrorAction SilentlyContinue) -and (Test-CommandSuccess @("docker-compose", "version"))) { + return @("docker-compose") + } + + Fail "Docker Compose is not available. Install Docker Compose, then run this command again." +} + +function New-SecretKey { + $Bytes = New-Object byte[] 42 + $Generator = [System.Security.Cryptography.RandomNumberGenerator]::Create() + + try { + $Generator.GetBytes($Bytes) + } + finally { + $Generator.Dispose() + } + + return [Convert]::ToBase64String($Bytes) +} + +function Ensure-EnvFiles { + if (-not (Test-Path $DefaultEnvFile -PathType Leaf)) { + Fail "$DefaultEnvFile is missing." + } + + if (Test-Path $UserEnvFile -PathType Leaf) { + return + } + + New-Item -ItemType File -Path $UserEnvFile | Out-Null + + if ([Console]::IsInputRedirected) { + Write-Info "Created $UserEnvFile for local overrides." + return + } + + Write-Info "Created $UserEnvFile for local overrides." + $Answer = Read-Host "Do you need a custom deployment now? (Most users can press Enter to skip.) [y/N]" + + if ($Answer -match "^(y|yes)$") { + Write-Output "Edit .env with the settings you want to override, using .env.example as the full reference." + Write-Output "Run .\dify-compose.ps1 up -d again when you are ready." + exit 0 + } +} + +function Read-EnvFile { + param([string]$Path) + + $Values = [ordered]@{} + + if (-not (Test-Path $Path -PathType Leaf)) { + return $Values + } + + foreach ($Line in Get-Content -Path $Path) { + if ($Line -match "^\s*#" -or $Line -notmatch "=") { + continue + } + + $SeparatorIndex = $Line.IndexOf("=") + $Key = $Line.Substring(0, $SeparatorIndex).Trim() + $Value = $Line.Substring($SeparatorIndex + 1).Trim() + + if (($Value.StartsWith('"') -and $Value.EndsWith('"')) -or ($Value.StartsWith("'") -and $Value.EndsWith("'"))) { + $Value = $Value.Substring(1, $Value.Length - 2) + } + + if ($Key.Length -gt 0) { + $Values[$Key] = $Value + } + } + + return $Values +} + +function Set-UserEnvValue { + param( + [string]$Key, + [string]$Value + ) + + $Path = [string](Resolve-Path $UserEnvFile) + $Lines = [System.IO.File]::ReadAllLines($Path, [System.Text.Encoding]::UTF8) + $Output = New-Object System.Collections.Generic.List[string] + $Replaced = $false + + foreach ($Line in $Lines) { + if ($Line -match "^\s*#" -or $Line -notmatch "=") { + $Output.Add($Line) + continue + } + + $SeparatorIndex = $Line.IndexOf("=") + $CurrentKey = $Line.Substring(0, $SeparatorIndex).Trim() + + if ($CurrentKey -eq $Key) { + if (-not $Replaced) { + $Output.Add("$Key=$Value") + $Replaced = $true + } + continue + } + + $Output.Add($Line) + } + + if (-not $Replaced) { + $Output.Add("$Key=$Value") + } + + [System.IO.File]::WriteAllLines($Path, $Output, $Utf8NoBom) +} + +function Ensure-SecretKey { + $Values = Read-EnvFile $UserEnvFile + + if ($Values.Contains("SECRET_KEY") -and $Values["SECRET_KEY"]) { + return + } + + Set-UserEnvValue "SECRET_KEY" (New-SecretKey) + Write-Info "Generated SECRET_KEY in $UserEnvFile." +} + +function Merge-EnvValues { + $Values = [ordered]@{} + + foreach ($Entry in (Read-EnvFile $DefaultEnvFile).GetEnumerator()) { + $Values[$Entry.Key] = $Entry.Value + } + + foreach ($Entry in (Read-EnvFile $UserEnvFile).GetEnumerator()) { + $Values[$Entry.Key] = $Entry.Value + } + + return $Values +} + +function User-Overrides { + param([string]$Key) + + if (-not (Test-Path $UserEnvFile -PathType Leaf)) { + return $false + } + + return [bool](Select-String -Path $UserEnvFile -Pattern "^\s*$([regex]::Escape($Key))\s*=" -Quiet) +} + +function Metadata-DbHost { + param([string]$DbType, $Values) + + switch ($DbType) { + "mysql" { return "db_mysql" } + "postgresql" { return "db_postgres" } + "" { return "db_postgres" } + default { return $Values["DB_HOST"] } + } +} + +function Metadata-DbPort { + param([string]$DbType, $Values) + + switch ($DbType) { + "mysql" { return "3306" } + "postgresql" { return "5432" } + "" { return "5432" } + default { return $Values["DB_PORT"] } + } +} + +function Metadata-DbUser { + param([string]$DbType, $Values) + + switch ($DbType) { + "mysql" { return "root" } + "postgresql" { return "postgres" } + "" { return "postgres" } + default { return $Values["DB_USERNAME"] } + } +} + +function Write-MergedEnv { + param($Values) + + $Output = New-Object System.Collections.Generic.List[string] + + foreach ($Entry in $Values.GetEnumerator()) { + $Output.Add("$($Entry.Key)=$($Entry.Value)") + } + + [System.IO.File]::WriteAllLines($MergedEnvFile, $Output, $Utf8NoBom) +} + +function Build-MergedEnv { + $Values = Merge-EnvValues + $script:MergedEnvFile = [System.IO.Path]::GetTempFileName() + + $DbType = if ($Values.Contains("DB_TYPE")) { $Values["DB_TYPE"] } else { "postgresql" } + + if (-not (User-Overrides "DB_HOST")) { + $Values["DB_HOST"] = Metadata-DbHost $DbType $Values + } + + if (-not (User-Overrides "DB_PORT")) { + $Values["DB_PORT"] = Metadata-DbPort $DbType $Values + } + + if (-not (User-Overrides "DB_USERNAME")) { + $Values["DB_USERNAME"] = Metadata-DbUser $DbType $Values + } + + if (-not (User-Overrides "CELERY_BROKER_URL")) { + $RedisHost = if ($Values.Contains("REDIS_HOST") -and $Values["REDIS_HOST"]) { $Values["REDIS_HOST"] } else { "redis" } + $RedisPort = if ($Values.Contains("REDIS_PORT") -and $Values["REDIS_PORT"]) { $Values["REDIS_PORT"] } else { "6379" } + $RedisUsername = if ($Values.Contains("REDIS_USERNAME")) { $Values["REDIS_USERNAME"] } else { "" } + $RedisPassword = if ($Values.Contains("REDIS_PASSWORD")) { $Values["REDIS_PASSWORD"] } else { "" } + $RedisAuth = "" + + if ($RedisUsername -and $RedisPassword) { + $RedisAuth = "${RedisUsername}:${RedisPassword}@" + } + elseif ($RedisPassword) { + $RedisAuth = ":${RedisPassword}@" + } + elseif ($RedisUsername) { + $RedisAuth = "${RedisUsername}@" + } + + $Values["CELERY_BROKER_URL"] = "redis://$RedisAuth${RedisHost}:${RedisPort}/1" + } + + if (-not (User-Overrides "SANDBOX_API_KEY")) { + $CodeExecutionApiKey = if ($Values.Contains("CODE_EXECUTION_API_KEY") -and $Values["CODE_EXECUTION_API_KEY"]) { $Values["CODE_EXECUTION_API_KEY"] } else { "dify-sandbox" } + $Values["SANDBOX_API_KEY"] = $CodeExecutionApiKey + } + + if (-not (User-Overrides "WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS")) { + $WeaviateApiKey = if ($Values.Contains("WEAVIATE_API_KEY") -and $Values["WEAVIATE_API_KEY"]) { $Values["WEAVIATE_API_KEY"] } else { "WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih" } + $Values["WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS"] = $WeaviateApiKey + } + + Write-MergedEnv $Values +} + +$ComposeCommand = Get-ComposeCommand + +try { + Ensure-EnvFiles + Ensure-SecretKey + Build-MergedEnv + + $ComposeArgs = @($args) + if ($ComposeArgs.Count -eq 0) { + $ComposeArgs = @("up", "-d") + } + + $ComposeCommandArgs = @() + if ($ComposeCommand.Length -gt 1) { + $ComposeCommandArgs = @($ComposeCommand[1..($ComposeCommand.Length - 1)]) + } + + $ComposeExecutable = $ComposeCommand[0] + & $ComposeExecutable @ComposeCommandArgs --env-file $MergedEnvFile @ComposeArgs + exit $LASTEXITCODE +} +finally { + if ($MergedEnvFile -and (Test-Path $MergedEnvFile -PathType Leaf)) { + Remove-Item -Force $MergedEnvFile + } +} diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 87fa01f671..b2df61ebb2 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -170,8 +170,8 @@ services: ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} - TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} - INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} + INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} @@ -402,8 +402,8 @@ services: - ./certbot/update-cert.template.txt:/update-cert.template.txt - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh environment: - - CERTBOT_EMAIL=${CERTBOT_EMAIL} - - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} + - CERTBOT_EMAIL=${CERTBOT_EMAIL:-} + - CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} entrypoint: ["/docker-entrypoint.sh"] command: ["tail", "-f", "/dev/null"] diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a72136049d..6dcab4a9fc 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -441,10 +441,10 @@ x-shared-env: &shared-api-worker-env NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-} NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-} NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} - MAIL_TYPE: ${MAIL_TYPE:-resend} + MAIL_TYPE: ${MAIL_TYPE:-} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} - RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} + RESEND_API_KEY: ${RESEND_API_KEY:-} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} SMTP_USERNAME: ${SMTP_USERNAME:-} @@ -586,8 +586,8 @@ x-shared-env: &shared-api-worker-env NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false} - CERTBOT_EMAIL: ${CERTBOT_EMAIL:-your_email@example.com} - CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-your_domain.com} + CERTBOT_EMAIL: ${CERTBOT_EMAIL:-} + CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-} CERTBOT_OPTIONS: ${CERTBOT_OPTIONS:-} SSRF_HTTP_PORT: ${SSRF_HTTP_PORT:-3128} SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} @@ -894,8 +894,8 @@ services: ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} - TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} - INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} + INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} @@ -1126,8 +1126,8 @@ services: - ./certbot/update-cert.template.txt:/update-cert.template.txt - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh environment: - - CERTBOT_EMAIL=${CERTBOT_EMAIL} - - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} + - CERTBOT_EMAIL=${CERTBOT_EMAIL:-} + - CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} entrypoint: ["/docker-entrypoint.sh"] command: ["tail", "-f", "/dev/null"] diff --git a/eslint-suppressions.json b/eslint-suppressions.json index b4876dcf45..b5e67df509 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -315,20 +315,12 @@ "count": 4 } }, - "web/app/components/app/configuration/config-var/config-modal/type-select.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/configuration/config-var/index.tsx": { "no-restricted-imports": { "count": 1 } }, "web/app/components/app/configuration/config-var/select-var-type.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -363,9 +355,6 @@ } }, "web/app/components/app/configuration/config/assistant-type-picker/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } @@ -401,11 +390,6 @@ "count": 1 } }, - "web/app/components/app/configuration/config/automatic/version-selector.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx": { "no-restricted-imports": { "count": 1 @@ -774,11 +758,6 @@ "count": 1 } }, - "web/app/components/base/chat/chat-with-history/sidebar/rename-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/base/chat/chat/answer/agent-content.tsx": { "style/multiline-ternary": { "count": 2 @@ -800,11 +779,6 @@ "count": 1 } }, - "web/app/components/base/chat/chat/answer/operation.tsx": { - "no-restricted-imports": { - "count": 2 - } - }, "web/app/components/base/chat/chat/answer/workflow-process.tsx": { "react/set-state-in-effect": { "count": 1 @@ -1055,14 +1029,6 @@ "count": 3 } }, - "web/app/components/base/form/components/base/base-field.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 3 - } - }, "web/app/components/base/form/components/base/base-form.tsx": { "ts/no-explicit-any": { "count": 6 @@ -1589,14 +1555,6 @@ "count": 1 } }, - "web/app/components/base/modal/modal.stories.tsx": { - "no-console": { - "count": 4 - }, - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/base/new-audio-button/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2116,11 +2074,6 @@ "count": 1 } }, - "web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2389,11 +2342,6 @@ "count": 1 } }, - "web/app/components/datasets/settings/index-method/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/develop/code.tsx": { "ts/no-empty-object-type": { "count": 1 @@ -2579,11 +2527,8 @@ "erasable-syntax-only/enums": { "count": 1 }, - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { - "count": 3 + "count": 2 } }, "web/app/components/header/account-setting/model-provider-page/declarations.ts": { @@ -2612,11 +2557,6 @@ "count": 1 } }, - "web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts": { "no-barrel-files/no-barrel-files": { "count": 6 @@ -2912,44 +2852,11 @@ "count": 1 } }, - "web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts": { - "erasable-syntax-only/enums": { - "count": 2 - } - }, - "web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { - "no-barrel-files/no-barrel-files": { - "count": 3 - } - }, - "web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts": { "erasable-syntax-only/enums": { "count": 1 } }, - "web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx": { - "erasable-syntax-only/enums": { - "count": 1 - }, - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": { "no-barrel-files/no-barrel-files": { "count": 2 @@ -2978,11 +2885,6 @@ "count": 7 } }, - "web/app/components/plugins/plugin-detail-panel/tool-selector/components/schema-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-item.tsx": { "no-restricted-imports": { "count": 1 @@ -2998,11 +2900,6 @@ "count": 5 } }, - "web/app/components/plugins/plugin-item/action.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/plugins/plugin-item/index.tsx": { "no-restricted-imports": { "count": 1 @@ -3681,11 +3578,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/_base/components/error-handle/types.ts": { "erasable-syntax-only/enums": { "count": 1 @@ -3782,11 +3674,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/_base/components/variable/var-type-picker.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts": { "react/no-unnecessary-use-prefix": { "count": 2 @@ -4106,16 +3993,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/if-else/components/condition-number-input.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/if-else/default.ts": { "ts/no-explicit-any": { "count": 1 @@ -4151,11 +4028,6 @@ "count": 6 } }, - "web/app/components/workflow/nodes/knowledge-base/components/chunk-structure/selector.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/hooks.tsx": { "ts/no-explicit-any": { "count": 4 @@ -4199,11 +4071,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/metadata-filter-selector.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-retrieval/default.ts": { "ts/no-explicit-any": { "count": 1 @@ -4285,11 +4152,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/type-selector.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts": { "ts/no-explicit-any": { "count": 1 @@ -4341,16 +4203,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/loop/components/condition-list/condition-operator.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/workflow/nodes/loop/components/condition-number-input.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx": { "ts/no-explicit-any": { "count": 3 @@ -4522,9 +4374,6 @@ } }, "web/app/components/workflow/nodes/tool/components/tool-form/item.tsx": { - "no-restricted-imports": { - "count": 1 - }, "ts/no-explicit-any": { "count": 1 } diff --git a/eslint.config.mjs b/eslint.config.mjs index ae9fdaff01..1380ed67d2 100644 --- a/eslint.config.mjs +++ b/eslint.config.mjs @@ -4,6 +4,17 @@ import antfu, { GLOB_MARKDOWN } from '@antfu/eslint-config' import md from 'eslint-markdown' import markdownPreferences from 'eslint-plugin-markdown-preferences' +const GENERATED_IGNORES = [ + '**/storybook-static/', + '**/.next/', + 'web/next/', + 'web/next-env.d.ts', + '**/dist/', + '**/coverage/', + 'e2e/.auth/', + 'e2e/cucumber-report/', +] + export default antfu( { ignores: original => [ @@ -15,6 +26,7 @@ export default antfu( '!package.json', '!pnpm-workspace.yaml', '!vite.config.ts', + ...GENERATED_IGNORES, ...original, ], typescript: { diff --git a/package.json b/package.json index a563b574f7..baef89c4e8 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "dify", "type": "module", "private": true, - "packageManager": "pnpm@11.0.0", + "packageManager": "pnpm@11.0.6", "engines": { "node": "^22.22.1" }, diff --git a/packages/dify-ui/AGENTS.md b/packages/dify-ui/AGENTS.md index d8a59b7a0b..bdc2160702 100644 --- a/packages/dify-ui/AGENTS.md +++ b/packages/dify-ui/AGENTS.md @@ -56,4 +56,28 @@ The Figma design system uses `--radius/*` tokens whose scale is **offset by one - When the Figma MCP returns `rounded-[var(--radius/sm, 6px)]`, convert it to the standard Tailwind class from the table above (e.g. `rounded-md`). - For values without a standard Tailwind equivalent (10px, 20px, 28px), use arbitrary values like `rounded-[10px]`. +## Search / Picker Primitive Selection: Autocomplete vs Combobox vs Select + +Pick by whether the user is entering free-form text, choosing a remembered value, or selecting from a closed list. + +Base UI decision rules: + +- [Autocomplete docs]: use `Combobox` instead of `Autocomplete` if the selection should be remembered and the input value cannot be custom. +- [Combobox docs]: do not use `Combobox` for simple search widgets that require unrestricted text entry; use `Autocomplete` instead. + +Apply this split in Dify UI: + +- `Autocomplete` — free-form text input with optional suggestions or completions. The input value may be custom and does not necessarily become a selected option. Use for search boxes, command-style suggestions, tag suggestions, and async text completion. +- `Combobox` — searchable picker whose value is one or more selected items from a collection. The chosen value is remembered by the root, and free-form text is not the final value. Use for model pickers, user pickers, dataset/document pickers, and multi-select chips. +- `Select` — closed-list picker without text entry. Use when the option set is small or already scannable and filtering is unnecessary. + +Composition rules: + +- Keep Base UI primitive semantics visible in the public API. Export compound parts such as `ComboboxInputGroup`, `ComboboxInput`, `ComboboxContent`, `ComboboxList`, `ComboboxItem`, and `ComboboxItemIndicator` instead of wrapping them into one business component. +- For `Combobox` multiple selection, follow the official chips pattern: `ComboboxInputGroup` contains `ComboboxChips`, `ComboboxValue` renders `ComboboxChip` items, and `ComboboxInput` remains inside the chips row. Chips should wrap and let the input group grow vertically instead of forcing horizontal overflow. +- Content primitives must own their Base UI `Portal` and use `z-1002` on `Positioner`, matching the overlay contract in `README.md`. +- Use `w-(--anchor-width)` with viewport-aware max-width for `Autocomplete` and `Combobox` popups. Do not add `min-w-(--anchor-width)` when it would defeat available-width clamping. + +[Autocomplete docs]: https://base-ui.com/react/components/autocomplete.md#usage-guidelines +[Combobox docs]: https://base-ui.com/react/components/combobox.md#usage-guidelines [docs]: https://base-ui.com/react/components/tooltip#infotips diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index cd24a0c078..bdeeec33cb 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -36,12 +36,12 @@ Importing from `@langgenius/dify-ui` (no subpath) is intentionally not supported ## Primitives -| Category | Subpath | Notes | -| -------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------- | -| Overlay | `./alert-dialog`, `./context-menu`, `./dialog`, `./dropdown-menu`, `./popover`, `./select`, `./toast`, `./tooltip` | Portalled. See [Overlay & portal contract] below. | -| Form | `./number-field`, `./slider`, `./switch` | Controlled / uncontrolled per Base UI defaults. | -| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. | -| Media | `./avatar`, `./button` | Button exposes `cva` variants. | +| Category | Subpath | Notes | +| -------- | -------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------- | +| Overlay | `./alert-dialog`, `./autocomplete`, `./combobox`, `./context-menu`, `./dialog`, `./dropdown-menu`, `./popover`, `./select`, `./toast`, `./tooltip` | Portalled. See [Overlay & portal contract] below. | +| Form | `./autocomplete`, `./combobox`, `./number-field`, `./slider`, `./switch` | Controlled / uncontrolled per Base UI defaults. | +| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. | +| Media | `./avatar`, `./button` | Button exposes `cva` variants. | Utilities: @@ -65,7 +65,7 @@ If a consumer uses Dify UI source files through the workspace, add an explicit s ## Overlay & portal contract -All overlay primitives (`dialog`, `alert-dialog`, `popover`, `dropdown-menu`, `context-menu`, `select`, `tooltip`, `toast`) render their content inside a [Base UI Portal] attached to `document.body`. This is the Base UI default — see the upstream [Portals][Base UI Portal] docs for the underlying behavior. Consumers **do not** need to wrap anything in a portal manually. +All overlay primitives (`dialog`, `alert-dialog`, `autocomplete`, `combobox`, `popover`, `dropdown-menu`, `context-menu`, `select`, `tooltip`, `toast`) render their content inside a [Base UI Portal] attached to `document.body`. This is the Base UI default — see the upstream [Portals][Base UI Portal] docs for the underlying behavior. Consumers **do not** need to wrap anything in a portal manually. ### Root isolation requirement @@ -83,14 +83,14 @@ Equivalent: any root element with `isolation: isolate` in CSS. Without it, overl Every overlay primitive uses a single, shared z-index. Do **not** override it at call sites. -| Layer | z-index | Where | -| ----------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- | -| Overlays (Dialog, AlertDialog, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop | -| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. | +| Layer | z-index | Where | +| ----------------------------------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- | +| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop | +| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. | -Rationale: during Dify's migration from legacy `portal-to-follow-elem` / `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins. +Rationale: during Dify's migration from legacy `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins. -See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history and the remaining legacy allowlist. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`. +See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`. ### Rules diff --git a/packages/dify-ui/package.json b/packages/dify-ui/package.json index 73c6c0bd22..20e94c7dee 100644 --- a/packages/dify-ui/package.json +++ b/packages/dify-ui/package.json @@ -13,6 +13,10 @@ "types": "./src/alert-dialog/index.tsx", "import": "./src/alert-dialog/index.tsx" }, + "./autocomplete": { + "types": "./src/autocomplete/index.tsx", + "import": "./src/autocomplete/index.tsx" + }, "./avatar": { "types": "./src/avatar/index.tsx", "import": "./src/avatar/index.tsx" @@ -21,6 +25,10 @@ "types": "./src/button/index.tsx", "import": "./src/button/index.tsx" }, + "./combobox": { + "types": "./src/combobox/index.tsx", + "import": "./src/combobox/index.tsx" + }, "./context-menu": { "types": "./src/context-menu/index.tsx", "import": "./src/context-menu/index.tsx" @@ -103,6 +111,7 @@ "@storybook/addon-themes": "catalog:", "@storybook/react-vite": "catalog:", "@tailwindcss/vite": "catalog:", + "@tanstack/react-virtual": "catalog:", "@types/react": "catalog:", "@types/react-dom": "catalog:", "@typescript/native-preview": "catalog:", diff --git a/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx new file mode 100644 index 0000000000..a7031c5b12 --- /dev/null +++ b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx @@ -0,0 +1,252 @@ +import type { ReactNode } from 'react' +import { render } from 'vitest-browser-react' +import { + Autocomplete, + AutocompleteClear, + AutocompleteContent, + AutocompleteEmpty, + AutocompleteGroup, + AutocompleteInput, + AutocompleteInputGroup, + AutocompleteItem, + AutocompleteItemIndicator, + AutocompleteItemText, + AutocompleteLabel, + AutocompleteList, + AutocompleteSeparator, + AutocompleteStatus, + AutocompleteTrigger, +} from '../index' + +const renderWithSafeViewport = (ui: ReactNode) => render( +
+ {ui} +
, +) + +const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement + +const renderAutocomplete = ({ + children, + open = false, + defaultValue = 'workflow', +}: { + children?: ReactNode + open?: boolean + defaultValue?: string +} = {}) => renderWithSafeViewport( + + {children ?? ( + <> + + + + + + + 2 suggestions + + + Workflow + + + + Dataset + + + No suggestions + + + )} + , +) + +describe('Autocomplete wrappers', () => { + describe('Input group and input', () => { + it('should apply medium input group and input classes by default', async () => { + const screen = await renderAutocomplete() + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-lg') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('px-3') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('system-sm-regular') + }) + + it('should apply large input group and input classes when large size is provided', async () => { + const screen = await renderAutocomplete({ + children: ( + + + + ), + }) + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-[10px]') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('px-4') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('system-md-regular') + }) + + it('should set input defaults and forward passthrough props', async () => { + const screen = await renderAutocomplete({ + children: ( + + + + ), + }) + + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('autocomplete', 'off') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('type', 'text') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('placeholder', 'Find a resource') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toBeRequired() + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('custom-input') + }) + }) + + describe('Controls', () => { + it('should provide fallback aria labels and decorative icons when labels are omitted', async () => { + const screen = await renderAutocomplete() + + await expect.element(screen.getByRole('button', { name: 'Clear autocomplete' })).toHaveAttribute('type', 'button') + await expect.element(screen.getByRole('button', { name: 'Open autocomplete suggestions' })).toHaveAttribute('type', 'button') + expect(screen.getByRole('button', { name: 'Clear autocomplete' }).element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true') + expect(screen.getByRole('button', { name: 'Open autocomplete suggestions' }).element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should preserve explicit labels and custom children', async () => { + const screen = await renderAutocomplete({ + children: ( + + + + reset + + + open + + + ), + }) + + expect(screen.getByRole('button', { name: 'Reset search' }).element()).toContainElement(screen.getByTestId('custom-clear').element()) + expect(screen.getByRole('button', { name: 'Show suggestions' }).element()).toContainElement(screen.getByTestId('custom-trigger').element()) + expect(screen.getByRole('button', { name: 'Reset search' }).element().querySelector('.i-ri-close-line')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Show suggestions' }).element().querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument() + }) + + it('should rely on aria-labelledby when provided instead of injecting fallback labels', async () => { + const screen = await renderAutocomplete({ + children: ( + <> + Clear from label + Trigger from label + + + + + + + ), + }) + + await expect.element(screen.getByRole('button', { name: 'Clear from label' })).not.toHaveAttribute('aria-label') + await expect.element(screen.getByRole('button', { name: 'Trigger from label' })).not.toHaveAttribute('aria-label') + }) + }) + + describe('Content and options', () => { + it('should use default overlay placement and Dify popup classes', async () => { + const screen = await renderAutocomplete({ open: true }) + + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'bottom') + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-align', 'start') + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('rounded-xl') + await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('w-(--anchor-width)') + await expect.element(screen.getByRole('listbox', { name: 'autocomplete list' })).toHaveClass('scroll-py-1') + }) + + it('should apply custom placement side and passthrough popup props', async () => { + const onPopupClick = vi.fn() + const screen = await renderWithSafeViewport( + + + + + + + + Workflow + + + + , + ) + + asHTMLElement(screen.getByRole('dialog', { name: 'autocomplete popup' }).element()).click() + + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'top') + expect(onPopupClick).toHaveBeenCalledTimes(1) + }) + + it('should render item text indicator status and empty wrappers with design classes', async () => { + const screen = await renderAutocomplete({ open: true }) + + await expect.element(screen.getByText('Workflow')).toHaveClass('system-sm-medium') + await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary') + await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular') + expect(screen.getByText('Workflow').element().parentElement?.querySelector('.i-ri-arrow-right-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should forward custom classes to label separator item text and indicator', async () => { + const screen = await renderWithSafeViewport( + + + + + + + + Resources + + + Workflow + + + + + + , + ) + + await expect.element(screen.getByText('Resources')).toHaveClass('custom-label') + await expect.element(screen.getByTestId('separator')).toHaveClass('custom-separator') + await expect.element(screen.getByRole('option', { name: 'Workflow' })).toHaveClass('custom-item') + await expect.element(screen.getByText('Workflow')).toHaveClass('custom-text') + await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator') + }) + }) +}) diff --git a/packages/dify-ui/src/autocomplete/index.stories.tsx b/packages/dify-ui/src/autocomplete/index.stories.tsx new file mode 100644 index 0000000000..71c7c6607d --- /dev/null +++ b/packages/dify-ui/src/autocomplete/index.stories.tsx @@ -0,0 +1,721 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import type { Virtualizer } from '@tanstack/react-virtual' +import type { RefObject } from 'react' +import { useVirtualizer } from '@tanstack/react-virtual' +import { useEffect, useMemo, useRef, useState } from 'react' +import { + Autocomplete, + AutocompleteClear, + AutocompleteCollection, + AutocompleteContent, + AutocompleteEmpty, + AutocompleteGroup, + AutocompleteInput, + AutocompleteInputGroup, + AutocompleteItem, + AutocompleteItemText, + AutocompleteLabel, + AutocompleteList, + AutocompleteSeparator, + AutocompleteStatus, + AutocompleteTrigger, + useAutocompleteFilter, + useAutocompleteFilteredItems, +} from '.' +import { cn } from '../cn' + +type Suggestion = { + value: string + label: string + description?: string + icon?: string + meta?: string +} + +type SuggestionGroup = { + label: string + items: Suggestion[] +} + +const inputWidth = 'w-80' + +type StoryVirtualizer = Virtualizer + +const scrollHighlightedVirtualItem = ( + item: unknown, + { + reason, + index, + }: { + reason: 'keyboard' | 'pointer' | 'none' + index: number + }, + virtualizer: StoryVirtualizer | null, +) => { + if (!item || !virtualizer) + return + + const isStart = index === 0 + const isEnd = index === virtualizer.options.count - 1 + const shouldScroll = reason === 'none' || (reason === 'keyboard' && (isStart || isEnd)) + + if (shouldScroll) { + queueMicrotask(() => { + virtualizer.scrollToIndex(index, { align: isEnd ? 'start' : 'end' }) + }) + } +} + +const tagSuggestions: Suggestion[] = [ + { value: 'feature', label: 'feature', description: 'Product work and launch notes' }, + { value: 'fix', label: 'fix', description: 'Bug fixes and regressions' }, + { value: 'docs', label: 'docs', description: 'Documentation updates' }, + { value: 'internal', label: 'internal', description: 'Workspace-only notes' }, + { value: 'mobile', label: 'mobile', description: 'Mobile app issues' }, + { value: 'component: autocomplete', label: 'component: autocomplete', description: 'Base UI primitive wrapper' }, + { value: 'component: combobox', label: 'component: combobox', description: 'Filterable predefined selection' }, + { value: 'component: select', label: 'component: select', description: 'Compact predefined selection' }, +] + +const promptCompletions: Suggestion[] = [ + { value: 'summarize this conversation', label: 'summarize this conversation' }, + { value: 'summarize this dataset with citations', label: 'summarize this dataset with citations' }, + { value: 'summarize this workflow run for an operator', label: 'summarize this workflow run for an operator' }, + { value: 'summarize this support ticket in 3 bullets', label: 'summarize this support ticket in 3 bullets' }, +] + +const workflowSuggestions: Suggestion[] = [ + { value: 'http-request', label: 'HTTP Request', description: 'Call an external API', icon: 'i-ri-global-line', meta: 'Tool' }, + { value: 'knowledge-retrieval', label: 'Knowledge Retrieval', description: 'Search configured datasets', icon: 'i-ri-database-2-line', meta: 'Tool' }, + { value: 'code-execution', label: 'Code Execution', description: 'Run sandboxed snippets', icon: 'i-ri-code-s-slash-line', meta: 'Tool' }, + { value: 'template-transform', label: 'Template Transform', description: 'Compose variables into output', icon: 'i-ri-braces-line', meta: 'Tool' }, + { value: 'question-classifier', label: 'Question Classifier', description: 'Route by intent', icon: 'i-ri-git-branch-line', meta: 'Tool' }, + { value: 'parameter-extractor', label: 'Parameter Extractor', description: 'Extract typed values', icon: 'i-ri-list-check-3', meta: 'Tool' }, + { value: 'answer-node', label: 'Answer Node', description: 'Return a final assistant answer', icon: 'i-ri-message-3-line', meta: 'Node' }, + { value: 'iteration-node', label: 'Iteration Node', description: 'Run a loop over array items', icon: 'i-ri-repeat-line', meta: 'Node' }, + { value: 'variable-assigner', label: 'Variable Assigner', description: 'Persist intermediate state', icon: 'i-ri-pencil-ruler-2-line', meta: 'Node' }, +] + +const groupedSuggestions: SuggestionGroup[] = [ + { + label: 'Tags', + items: tagSuggestions.slice(0, 5), + }, + { + label: 'Workflow Suggestions', + items: workflowSuggestions.slice(0, 5), + }, + { + label: 'Prompt Starters', + items: promptCompletions.slice(0, 3), + }, +] + +const commandGroups: SuggestionGroup[] = [ + { + label: 'App', + items: [ + { value: '/run', label: 'Run workflow', description: 'Execute the current draft', icon: 'i-ri-play-circle-line' }, + { value: '/publish', label: 'Publish app', description: 'Ship the current configuration', icon: 'i-ri-upload-cloud-2-line' }, + { value: '/trace', label: 'Open trace', description: 'Inspect the latest workflow run', icon: 'i-ri-route-line' }, + ], + }, + { + label: 'Workspace', + items: [ + { value: '/dataset', label: 'Search datasets', description: 'Find knowledge attached to this app', icon: 'i-ri-database-line' }, + { value: '/members', label: 'Invite members', description: 'Open workspace access settings', icon: 'i-ri-user-add-line' }, + { value: '/usage', label: 'View usage', description: 'Open model and workflow usage', icon: 'i-ri-bar-chart-line' }, + ], + }, +] + +const remoteSuggestions: Suggestion[] = [ + { value: 'agent-builder', label: 'Agent Builder', description: 'Workspace app' }, + { value: 'agent-observability', label: 'Agent Observability', description: 'Dataset' }, + { value: 'agent-routing-dataset', label: 'Agent Routing Dataset', description: 'Knowledge source' }, +] + +const virtualizedSuggestions: Suggestion[] = Array.from({ length: 1000 }, (_, index) => { + const family = ['workflow', 'dataset', 'prompt', 'tool'][index % 4]! + const number = new Intl.NumberFormat('en-US', { + minimumIntegerDigits: 4, + }).format(index + 1) + + return { + value: `${family}-${index + 1}`, + label: `${family} suggestion ${number}`, + description: `Free-form autocomplete result from ${family} search`, + icon: family === 'dataset' + ? 'i-ri-database-2-line' + : family === 'prompt' + ? 'i-ri-text-snippet' + : family === 'tool' + ? 'i-ri-tools-line' + : 'i-ri-flow-chart', + meta: family, + } +}) + +const getSuggestionLabel = (item: Suggestion) => item.label + +const SuggestionItem = ({ + item, + index, + dense, +}: { + item: Suggestion + index?: number + dense?: boolean +}) => ( + + {item.icon && +) + +const TagSuggestionItem = ({ + item, + index, +}: { + item: Suggestion + index?: number +}) => ( + + {item.label} + {item.description && {item.description}} + +) + +const BasicTagAutocomplete = ({ + size = 'medium', +}: { + size?: 'small' | 'medium' | 'large' +}) => ( + + + + + + {(item: Suggestion, index: number) => ( + + )} + + No tag suggestion. Keep the typed value. + + +) + +const GroupedSuggestionList = () => { + const groups = useAutocompleteFilteredItems() + + return ( + + {groups.map((group, groupIndex) => ( + + {groupIndex > 0 && } + {group.label} + + {(item: Suggestion) => ( + + )} + + + ))} + + ) +} + +const CommandPaletteList = () => { + const groups = useAutocompleteFilteredItems() + + return ( + + {groups.map((group, groupIndex) => ( + + {groupIndex > 0 && } + {group.label} + + {(item: Suggestion) => ( + + + {item.icon && + + Enter + + + )} + + + ))} + + ) +} + +const LimitedStatus = ({ + total, +}: { + total: number +}) => { + const items = useAutocompleteFilteredItems() + const hidden = Math.max(0, total - items.length) + + return hidden > 0 + ? `${hidden} more suggestions hidden. Refine the query to narrow results.` + : `${items.length} suggestions available.` +} + +const AsyncSearchDemo = () => { + const [value, setValue] = useState('agent') + const [loading, setLoading] = useState(false) + const [items, setItems] = useState(remoteSuggestions) + + useEffect(() => { + setLoading(true) + const timeout = window.setTimeout(() => { + setItems( + value.trim() + ? remoteSuggestions.filter(item => item.label.toLowerCase().includes(value.trim().toLowerCase())) + : remoteSuggestions, + ) + setLoading(false) + }, 500) + + return () => window.clearTimeout(timeout) + }, [value]) + + return ( +
+ + + + + + {loading ? 'Loading suggestions…' : `${items.length} remote suggestions`} + + + {(item: Suggestion, index: number) => ( + + )} + + No remote suggestion. Keep the typed query. + + +
+ ) +} + +const VirtualizedSuggestionList = ({ + virtualizerRef, +}: { + virtualizerRef: RefObject +}) => { + const scrollRef = useRef(null) + const filteredItems = useAutocompleteFilteredItems() + const virtualizer = useVirtualizer({ + count: filteredItems.length, + getScrollElement: () => scrollRef.current, + estimateSize: () => 44, + overscan: 6, + }) + + useEffect(() => { + virtualizerRef.current = virtualizer + + return () => { + virtualizerRef.current = null + } + }, [virtualizer, virtualizerRef]) + + return ( +
+ + {virtualizer.getVirtualItems().map((virtualItem) => { + const item = filteredItems[virtualItem.index] + + if (!item) + return null + + return ( +
+ +
+ ) + })} +
+
+ ) +} + +const VirtualizedStatus = () => { + const filteredItems = useAutocompleteFilteredItems() + + return ( + + {filteredItems.length} + {' '} + matching suggestions. Selecting one only replaces the input text. + + ) +} + +const FuzzyHighlight = ({ + text, + query, +}: { + text: string + query: string +}) => { + const parts = useMemo(() => { + const trimmed = query.trim() + + if (!trimmed) + return [text] + + const escaped = trimmed.slice(0, 80).replace(/[.*+?^${}()|[\]\\]/g, '\\$&') + return text.split(new RegExp(`(${escaped})`, 'i')) + }, [query, text]) + + return ( + <> + {parts.map((part, index) => ( + part.toLowerCase() === query.trim().toLowerCase() + ? {part} + : part + ))} + + ) +} + +const FuzzyMatchingDemo = () => { + const [value, setValue] = useState('retr') + const { contains } = useAutocompleteFilter({ sensitivity: 'base' }) + + return ( +
+ + + + + + {(item: Suggestion, index: number) => ( + + {item.icon && + )} + + No workflow suggestion. Keep typing freely. + + +
+ ) +} + +const meta = { + title: 'Base/UI/Autocomplete', + component: Autocomplete, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Compound autocomplete built on Base UI Autocomplete. Use it for free-form inputs where suggestions can replace or complete the typed text, but selection is not persistent state.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +export const SearchTags: Story = { + render: () => ( +
+ +
+ ), +} + +export const Sizes: Story = { + render: () => ( +
+ {(['small', 'medium', 'large'] as const).map(size => ( +
+ +
+ ))} +
+ ), +} + +export const InlineAutocomplete: Story = { + render: () => ( +
+ + + + + + {(item: Suggestion, index: number) => ( + + )} + + No inline completion. Continue typing freely. + + +
+ ), +} + +export const GroupedSuggestions: Story = { + render: () => ( +
+ + + + + + No suggestion. Use the text as entered. + + +
+ ), +} + +export const FuzzyMatching: Story = { + render: () => , +} + +export const LimitResults: Story = { + render: () => ( +
+ + + + + + + + + {(item: Suggestion, index: number) => ( + + )} + + No suggestion. Submit the typed text instead. + + +
+ ), +} + +export const CommandPalette: Story = { + render: () => ( +
+ + + + + +
+ ), +} + +const VirtualizedLongSuggestionsDemo = () => { + const virtualizerRef = useRef(null) + + return ( +
+ { + scrollHighlightedVirtualItem(item, details, virtualizerRef.current) + }} + > + + + + + + No suggestion. Free-form text is still valid. + + +
+ ) +} + +export const VirtualizedLongSuggestions: Story = { + render: () => , +} + +export const AsyncSearch: Story = { + render: () => , +} + +export const Empty: Story = { + render: () => ( +
+ + + + + + {(item: Suggestion, index: number) => ( + + )} + + No tag suggestion. The custom text remains valid. + + +
+ ), +} + +export const DisabledAndReadOnly: Story = { + render: () => ( +
+ + + + + + + + + {(item: Suggestion, index: number) => ( + + )} + + + + + + + + + + + + {(item: Suggestion, index: number) => ( + + )} + + + +
+ ), +} diff --git a/packages/dify-ui/src/autocomplete/index.tsx b/packages/dify-ui/src/autocomplete/index.tsx new file mode 100644 index 0000000000..16c4b19673 --- /dev/null +++ b/packages/dify-ui/src/autocomplete/index.tsx @@ -0,0 +1,381 @@ +'use client' + +import type { VariantProps } from 'class-variance-authority' +import type { HTMLAttributes, ReactNode } from 'react' +import type { Placement } from '../placement' +import { Autocomplete as BaseAutocomplete } from '@base-ui/react/autocomplete' +import { cva } from 'class-variance-authority' +import { cn } from '../cn' +import { + overlayIndicatorClassName, + overlayLabelClassName, + overlayPopupAnimationClassName, + overlaySeparatorClassName, +} from '../overlay-shared' +import { parsePlacement } from '../placement' + +export type { Placement } + +export const Autocomplete = BaseAutocomplete.Root +export const AutocompleteValue = BaseAutocomplete.Value +export const AutocompleteGroup = BaseAutocomplete.Group +export const AutocompleteCollection = BaseAutocomplete.Collection +export const AutocompleteRow = BaseAutocomplete.Row +export const useAutocompleteFilter = BaseAutocomplete.useFilter +export const useAutocompleteFilteredItems = BaseAutocomplete.useFilteredItems + +export type AutocompleteRootProps = BaseAutocomplete.Root.Props +export type AutocompleteRootChangeEventDetails = BaseAutocomplete.Root.ChangeEventDetails +export type AutocompleteRootHighlightEventDetails = BaseAutocomplete.Root.HighlightEventDetails + +const autocompletePopupClassName = [ + 'w-(--anchor-width) max-w-[min(28rem,var(--available-width))] overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg outline-hidden', + 'data-side-top:origin-bottom data-side-bottom:origin-top data-side-left:origin-right data-side-right:origin-left', +] + +const autocompleteListClassName = [ + 'max-h-[min(20rem,var(--available-height))] overflow-y-auto overflow-x-hidden overscroll-contain p-1 outline-hidden scroll-py-1', + 'data-empty:max-h-none data-empty:p-0', +] + +const autocompleteItemClassName = [ + 'mx-1 flex min-h-8 cursor-pointer select-none items-center gap-2 rounded-lg px-2 py-1.5 text-text-secondary outline-hidden transition-colors', + 'hover:bg-state-base-hover-alt hover:text-text-primary', + 'data-highlighted:bg-state-base-hover data-highlighted:text-text-primary', + 'data-disabled:cursor-not-allowed data-disabled:opacity-30 data-disabled:hover:bg-transparent data-disabled:hover:text-text-secondary', + 'motion-reduce:transition-none', +] + +const autocompleteInputGroupVariants = cva( + [ + 'group/autocomplete flex w-full min-w-0 items-center border border-transparent bg-components-input-bg-normal text-components-input-text-filled shadow-none outline-hidden transition-[background-color,border-color,box-shadow]', + 'hover:border-components-input-border-hover hover:bg-components-input-bg-hover', + 'focus-within:border-components-input-border-active focus-within:bg-components-input-bg-active focus-within:shadow-xs', + 'data-focused:border-components-input-border-active data-focused:bg-components-input-bg-active data-focused:shadow-xs', + 'data-disabled:cursor-not-allowed data-disabled:border-transparent data-disabled:bg-components-input-bg-disabled data-disabled:text-components-input-text-filled-disabled', + 'data-disabled:hover:border-transparent data-disabled:hover:bg-components-input-bg-disabled', + 'data-readonly:shadow-none data-readonly:hover:border-transparent data-readonly:hover:bg-components-input-bg-normal', + 'motion-reduce:transition-none', + ], + { + variants: { + size: { + small: 'h-6 rounded-md', + medium: 'h-8 rounded-lg', + large: 'h-9 rounded-[10px]', + }, + }, + defaultVariants: { + size: 'medium', + }, + }, +) + +export type AutocompleteSize = NonNullable['size']> + +export type AutocompleteInputGroupProps + = BaseAutocomplete.InputGroup.Props + & VariantProps + +export function AutocompleteInputGroup({ + className, + size = 'medium', + ...props +}: AutocompleteInputGroupProps) { + return ( + + ) +} + +const autocompleteInputVariants = cva( + [ + 'w-0 min-w-0 flex-1 appearance-none border-0 bg-transparent text-components-input-text-filled caret-primary-600 outline-hidden', + 'placeholder:text-components-input-text-placeholder', + 'disabled:cursor-not-allowed disabled:text-components-input-text-filled-disabled disabled:placeholder:text-components-input-text-disabled', + 'data-readonly:cursor-default', + ], + { + variants: { + size: { + small: 'px-2 py-1 system-xs-regular', + medium: 'px-3 py-[7px] system-sm-regular', + large: 'px-4 py-2 system-md-regular', + }, + }, + defaultVariants: { + size: 'medium', + }, + }, +) + +export type AutocompleteInputProps + = Omit + & VariantProps + +export function AutocompleteInput({ + className, + size = 'medium', + type = 'text', + autoComplete = 'off', + ...props +}: AutocompleteInputProps) { + return ( + + ) +} + +const autocompleteControlVariants = cva( + [ + 'flex shrink-0 touch-manipulation items-center justify-center rounded-md text-text-tertiary outline-hidden transition-colors', + 'hover:bg-components-input-bg-hover hover:text-text-secondary focus-visible:bg-components-input-bg-hover focus-visible:text-text-secondary', + 'focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:ring-inset', + 'disabled:cursor-not-allowed disabled:hover:bg-transparent disabled:hover:text-text-tertiary disabled:focus-visible:bg-transparent disabled:focus-visible:ring-0', + 'group-data-disabled/autocomplete:cursor-not-allowed group-data-disabled/autocomplete:hover:bg-transparent group-data-disabled/autocomplete:focus-visible:bg-transparent group-data-disabled/autocomplete:focus-visible:ring-0', + 'group-data-readonly/autocomplete:hidden', + 'motion-reduce:transition-none', + ], + { + variants: { + size: { + small: 'mr-1 size-4', + medium: 'mr-1.5 size-5', + large: 'mr-2 size-5', + }, + }, + defaultVariants: { + size: 'medium', + }, + }, +) + +export type AutocompleteControlProps + = Omit + & VariantProps + & { className?: string } + +export function AutocompleteTrigger({ + className, + children, + size = 'medium', + type = 'button', + ...props +}: AutocompleteControlProps) { + return ( + + {children ?? + ) +} + +export type AutocompleteClearProps + = Omit + & VariantProps + & { className?: string } + +export function AutocompleteClear({ + className, + children, + size = 'medium', + type = 'button', + ...props +}: AutocompleteClearProps) { + return ( + + {children ?? + ) +} + +export function AutocompleteIcon({ + className, + children, + ...props +}: BaseAutocomplete.Icon.Props) { + return ( + + {children ?? + ) +} + +type AutocompleteContentProps = { + children: ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + portalProps?: Omit + positionerProps?: Omit< + BaseAutocomplete.Positioner.Props, + 'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset' + > + popupProps?: Omit< + BaseAutocomplete.Popup.Props, + 'children' | 'className' + > +} + +export function AutocompleteContent({ + children, + placement = 'bottom-start', + sideOffset = 4, + alignOffset = 0, + className, + popupClassName, + portalProps, + positionerProps, + popupProps, +}: AutocompleteContentProps) { + const { side, align } = parsePlacement(placement) + + return ( + + + + {children} + + + + ) +} + +export function AutocompleteList({ + className, + ...props +}: BaseAutocomplete.List.Props) { + return ( + + ) +} + +export function AutocompleteItem({ + className, + ...props +}: BaseAutocomplete.Item.Props) { + return ( + + ) +} + +export type AutocompleteItemTextProps = HTMLAttributes + +export function AutocompleteItemText({ + className, + ...props +}: AutocompleteItemTextProps) { + return ( + + ) +} + +export function AutocompleteLabel({ + className, + ...props +}: BaseAutocomplete.GroupLabel.Props) { + return ( + + ) +} + +export function AutocompleteSeparator({ + className, + ...props +}: BaseAutocomplete.Separator.Props) { + return ( + + ) +} + +export function AutocompleteEmpty({ + className, + ...props +}: BaseAutocomplete.Empty.Props) { + return ( + + ) +} + +export function AutocompleteStatus({ + className, + ...props +}: BaseAutocomplete.Status.Props) { + return ( + + ) +} + +export function AutocompleteItemIndicator({ + className, + children, + ...props +}: HTMLAttributes) { + return ( + + {children ?? + ) +} diff --git a/packages/dify-ui/src/combobox/__tests__/index.spec.tsx b/packages/dify-ui/src/combobox/__tests__/index.spec.tsx new file mode 100644 index 0000000000..705ebe9601 --- /dev/null +++ b/packages/dify-ui/src/combobox/__tests__/index.spec.tsx @@ -0,0 +1,363 @@ +import type { ReactNode } from 'react' +import { render } from 'vitest-browser-react' +import { + Combobox, + ComboboxChip, + ComboboxChipRemove, + ComboboxChips, + ComboboxClear, + ComboboxContent, + ComboboxEmpty, + ComboboxGroup, + ComboboxGroupLabel, + ComboboxInput, + ComboboxInputGroup, + ComboboxInputTrigger, + ComboboxItem, + ComboboxItemIndicator, + ComboboxItemText, + ComboboxLabel, + ComboboxList, + ComboboxSeparator, + ComboboxStatus, + ComboboxTrigger, + ComboboxValue, +} from '../index' + +const renderWithSafeViewport = (ui: ReactNode) => render( +
+ {ui} +
, +) + +const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement + +const renderSelectLikeCombobox = ({ + children, + open = false, +}: { + children?: ReactNode + open?: boolean +} = {}) => renderWithSafeViewport( + + {children ?? ( + <> + Resource type + + + + + 2 options + + + Workflow + + + + Dataset + + + No options + + + )} + , +) + +const renderInputCombobox = ({ + children, + open = false, +}: { + children?: ReactNode + open?: boolean +} = {}) => renderWithSafeViewport( + + {children ?? ( + <> + + + + + + + + + Workflow + + + + + + )} + , +) + +describe('Combobox wrappers', () => { + describe('Select-like trigger', () => { + it('should render label and apply medium trigger classes by default', async () => { + const screen = await renderSelectLikeCombobox() + + await expect.element(screen.getByText('Resource type')).toHaveClass('system-sm-medium') + await expect.element(screen.getByRole('combobox', { name: 'Resource type' })).toHaveClass('rounded-lg') + await expect.element(screen.getByRole('combobox', { name: 'Resource type' })).toHaveClass('system-sm-regular') + }) + + it('should apply small and large trigger size variants', async () => { + const smallScreen = await renderSelectLikeCombobox({ + children: ( + + + + ), + }) + + await expect.element(smallScreen.getByRole('combobox', { name: 'Small resource type' })).toHaveClass('rounded-md') + await expect.element(smallScreen.getByRole('combobox', { name: 'Small resource type' })).toHaveClass('system-xs-regular') + + const largeScreen = await renderSelectLikeCombobox({ + children: ( + + + + ), + }) + + await expect.element(largeScreen.getByRole('combobox', { name: 'Large resource type' })).toHaveClass('rounded-[10px]') + await expect.element(largeScreen.getByRole('combobox', { name: 'Large resource type' })).toHaveClass('system-md-regular') + }) + + it('should render default trigger icon and support hiding it', async () => { + const withIcon = await renderSelectLikeCombobox() + + expect(withIcon.getByTestId('trigger').element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true') + + const withoutIcon = await renderSelectLikeCombobox({ + children: ( + + + + ), + }) + + expect(withoutIcon.getByRole('combobox', { name: 'Resource type without icon' }).element().querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument() + }) + }) + + describe('Input group and controls', () => { + it('should apply medium input group and input classes by default', async () => { + const screen = await renderInputCombobox() + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-lg') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('px-3') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('system-sm-regular') + }) + + it('should apply large input group and input classes when large size is provided', async () => { + const screen = await renderInputCombobox({ + children: ( + + + + ), + }) + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-[10px]') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('px-4') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('system-md-regular') + }) + + it('should set input defaults and forward passthrough props', async () => { + const screen = await renderInputCombobox({ + children: ( + + + + ), + }) + + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('autocomplete', 'off') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('type', 'text') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('placeholder', 'Find a resource') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toBeRequired() + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('custom-input') + }) + + it('should provide fallback aria labels and decorative icons for input controls', async () => { + const screen = await renderInputCombobox() + + await expect.element(screen.getByRole('button', { name: 'Clear combobox' })).toHaveAttribute('type', 'button') + await expect.element(screen.getByRole('button', { name: 'Open combobox options' })).toHaveAttribute('type', 'button') + expect(screen.getByRole('button', { name: 'Clear combobox' }).element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true') + expect(screen.getByRole('button', { name: 'Open combobox options' }).element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should rely on aria-labelledby when provided instead of injecting fallback labels', async () => { + const screen = await renderInputCombobox({ + children: ( + <> + Clear from label + Trigger from label + + + + + + + ), + }) + + await expect.element(screen.getByRole('button', { name: 'Clear from label' })).not.toHaveAttribute('aria-label') + await expect.element(screen.getByRole('button', { name: 'Trigger from label' })).not.toHaveAttribute('aria-label') + }) + }) + + describe('Content and options', () => { + it('should use default overlay placement and Dify popup classes', async () => { + const screen = await renderSelectLikeCombobox({ open: true }) + + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'bottom') + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-align', 'start') + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('rounded-xl') + await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('w-(--anchor-width)') + await expect.element(screen.getByRole('listbox', { name: 'combobox list' })).toHaveClass('scroll-py-1') + }) + + it('should apply custom placement side and passthrough popup props', async () => { + const onPopupClick = vi.fn() + const screen = await renderWithSafeViewport( + + + + + + + + Workflow + + + + , + ) + + asHTMLElement(screen.getByRole('dialog', { name: 'combobox popup' }).element()).click() + + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'top') + expect(onPopupClick).toHaveBeenCalledTimes(1) + }) + + it('should render item text indicator status and empty wrappers with design classes', async () => { + const screen = await renderSelectLikeCombobox({ open: true }) + + await expect.element(screen.getByTestId('list').getByText('Workflow')).toHaveClass('system-sm-medium') + await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary') + await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular') + expect(screen.getByTestId('list').getByText('Workflow').element().parentElement?.querySelector('.i-ri-check-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should forward custom classes to group label separator item text and indicator', async () => { + const screen = await renderWithSafeViewport( + + + + + + + + Resources + + + Workflow + + + + + + , + ) + + await expect.element(screen.getByText('Resources')).toHaveClass('custom-label') + await expect.element(screen.getByTestId('separator')).toHaveClass('custom-separator') + await expect.element(screen.getByRole('option', { name: 'Workflow' })).toHaveClass('custom-item') + await expect.element(screen.getByTestId('custom-list').getByText('Workflow')).toHaveClass('custom-text') + await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator') + }) + }) + + describe('Multiple selection chips', () => { + it('should render chip wrappers and default remove button label', async () => { + const screen = await renderWithSafeViewport( + + + + {(selectedValue: string[]) => ( + + {selectedValue.map(item => ( + + {item} + + + ))} + + )} + + + + , + ) + + await expect.element(screen.getByTestId('chips')).toHaveClass('custom-chips') + await expect.element(screen.getByText('maya').element().parentElement!).toHaveClass('custom-chip') + await expect.element(screen.getByRole('button', { name: 'Remove selected item' })).toHaveAttribute('type', 'button') + expect(screen.getByTestId('remove-chip').element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should preserve chip remove aria-labelledby over fallback label', async () => { + const screen = await renderWithSafeViewport( + + + + {(selectedValue: string[]) => ( + + {selectedValue.map(item => ( + + Remove Maya + + + ))} + + )} + + + + , + ) + + await expect.element(screen.getByRole('button', { name: 'Remove Maya' })).not.toHaveAttribute('aria-label') + }) + }) +}) diff --git a/packages/dify-ui/src/combobox/index.stories.tsx b/packages/dify-ui/src/combobox/index.stories.tsx new file mode 100644 index 0000000000..f2b5f4d4c6 --- /dev/null +++ b/packages/dify-ui/src/combobox/index.stories.tsx @@ -0,0 +1,618 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import type { Virtualizer } from '@tanstack/react-virtual' +import type { RefObject } from 'react' +import { useVirtualizer } from '@tanstack/react-virtual' +import { useEffect, useRef, useState } from 'react' +import { + Combobox, + ComboboxChip, + ComboboxChipRemove, + ComboboxChips, + ComboboxClear, + ComboboxCollection, + ComboboxContent, + ComboboxEmpty, + ComboboxGroup, + ComboboxGroupLabel, + ComboboxInput, + ComboboxInputGroup, + ComboboxInputTrigger, + ComboboxItem, + ComboboxItemIndicator, + ComboboxItemText, + ComboboxLabel, + ComboboxList, + ComboboxSeparator, + ComboboxStatus, + ComboboxTrigger, + ComboboxValue, + useComboboxFilteredItems, +} from '.' +import { cn } from '../cn' + +type Option = { + value: string + label: string + meta?: string + icon?: string + disabled?: boolean +} + +type OptionGroup = { + label: string + items: Option[] +} + +const fieldWidth = 'w-80' +const wideFieldWidth = 'w-[520px]' +const nativeFieldLabelClassName = 'mb-1 block text-text-secondary system-sm-medium' + +type StoryVirtualizer = Virtualizer + +const scrollHighlightedVirtualItem = ( + item: unknown, + { + reason, + index, + }: { + reason: 'keyboard' | 'pointer' | 'none' + index: number + }, + virtualizer: StoryVirtualizer | null, +) => { + if (!item || !virtualizer) + return + + const isStart = index === 0 + const isEnd = index === virtualizer.options.count - 1 + const shouldScroll = reason === 'none' || (reason === 'keyboard' && (isStart || isEnd)) + + if (shouldScroll) { + queueMicrotask(() => { + virtualizer.scrollToIndex(index, { align: isEnd ? 'start' : 'end' }) + }) + } +} + +const providerOptions: Option[] = [ + { value: 'openai', label: 'OpenAI', meta: 'GPT-5, GPT-4.1', icon: 'i-ri-openai-fill' }, + { value: 'anthropic', label: 'Anthropic', meta: 'Claude Opus, Sonnet', icon: 'i-ri-sparkling-2-line' }, + { value: 'google', label: 'Google', meta: 'Gemini 2.5', icon: 'i-ri-google-fill' }, + { value: 'azure-openai', label: 'Azure OpenAI', meta: 'Enterprise workspace', icon: 'i-ri-microsoft-fill' }, + { value: 'localai', label: 'LocalAI', meta: 'Self-hosted endpoint', icon: 'i-ri-server-line', disabled: true }, +] + +const dataSourceOptions: Option[] = [ + { value: 'knowledge-base', label: 'Knowledge Base', meta: 'Vector index', icon: 'i-ri-database-2-line' }, + { value: 'notion', label: 'Notion', meta: 'Synced pages', icon: 'i-ri-notion-fill' }, + { value: 'website', label: 'Website crawler', meta: 'Public URLs', icon: 'i-ri-global-line' }, + { value: 's3', label: 'S3 bucket', meta: 'Private files', icon: 'i-ri-cloud-line' }, + { value: 'slack', label: 'Slack', meta: 'Channel history', icon: 'i-ri-slack-fill' }, +] + +const reviewerOptions: Option[] = [ + { value: 'maya', label: 'Maya Chen', meta: 'Product owner' }, + { value: 'liam', label: 'Liam Brooks', meta: 'Prompt engineer' }, + { value: 'nora', label: 'Nora Park', meta: 'Data steward' }, + { value: 'owen', label: 'Owen Reed', meta: 'Security reviewer' }, + { value: 'yuki', label: 'Yuki Tanaka', meta: 'ML engineer' }, +] + +const toolGroups: OptionGroup[] = [ + { + label: 'Retrieval', + items: [ + { value: 'dataset-search', label: 'Dataset search', meta: 'Search workspace knowledge', icon: 'i-ri-search-eye-line' }, + { value: 'web-scraper', label: 'Web scraper', meta: 'Fetch public pages', icon: 'i-ri-global-line' }, + ], + }, + { + label: 'Actions', + items: [ + { value: 'http-request', label: 'HTTP request', meta: 'Call external APIs', icon: 'i-ri-terminal-box-line' }, + { value: 'code-runner', label: 'Code runner', meta: 'Execute sandboxed scripts', icon: 'i-ri-code-s-slash-line' }, + ], + }, + { + label: 'Operations', + items: [ + { value: 'human-review', label: 'Human review', meta: 'Assign approval task', icon: 'i-ri-user-voice-line' }, + { value: 'audit-log', label: 'Audit log', meta: 'Record workflow events', icon: 'i-ri-file-list-3-line' }, + ], + }, +] + +const tagOptions: Option[] = [ + { value: 'rag', label: 'RAG' }, + { value: 'agent', label: 'Agent' }, + { value: 'production', label: 'Production' }, + { value: 'evaluation', label: 'Evaluation' }, + { value: 'finance', label: 'Finance' }, + { value: 'support', label: 'Support' }, +] + +const directoryOptions: Option[] = [ + { value: 'maya-chen', label: 'Maya Chen', meta: 'Product owner · maya@example.com', icon: 'i-ri-user-3-line' }, + { value: 'liam-brooks', label: 'Liam Brooks', meta: 'Prompt engineer · liam@example.com', icon: 'i-ri-user-3-line' }, + { value: 'nora-park', label: 'Nora Park', meta: 'Data steward · nora@example.com', icon: 'i-ri-user-3-line' }, + { value: 'owen-reed', label: 'Owen Reed', meta: 'Security reviewer · owen@example.com', icon: 'i-ri-shield-user-line' }, + { value: 'yuki-tanaka', label: 'Yuki Tanaka', meta: 'ML engineer · yuki@example.com', icon: 'i-ri-user-3-line' }, + { value: 'ava-martin', label: 'Ava Martin', meta: 'Support lead · ava@example.com', icon: 'i-ri-customer-service-2-line' }, +] + +const emptyOptions: Option[] = [ + { value: 'billing', label: 'Billing connector' }, + { value: 'zendesk', label: 'Zendesk' }, + { value: 'github', label: 'GitHub issues' }, +] + +const modelCatalogOptions: Option[] = Array.from({ length: 1000 }, (_, index) => { + const provider = ['OpenAI', 'Anthropic', 'Google', 'Mistral', 'DeepSeek'][index % 5]! + const family = ['chat', 'reasoning', 'vision', 'embedding'][index % 4]! + const number = new Intl.NumberFormat('en-US', { + minimumIntegerDigits: 4, + }).format(index + 1) + + return { + value: `model-${index + 1}`, + label: `${provider} ${family} ${number}`, + meta: `${provider} provider · ${family}`, + icon: family === 'embedding' + ? 'i-ri-vector-triangle' + : family === 'vision' + ? 'i-ri-image-circle-line' + : family === 'reasoning' + ? 'i-ri-brain-line' + : 'i-ri-chat-1-line', + } +}) + +const sizeOptions: Option[] = providerOptions.slice(0, 3) +const defaultProvider = providerOptions[0]! +const disabledProvider = providerOptions[1]! +const defaultDataSource = dataSourceOptions[0]! +const defaultPopupDataSource = dataSourceOptions[1]! +const readOnlyDataSource = dataSourceOptions[2]! +const defaultTool = toolGroups[0]!.items[0]! +const defaultReviewers = [reviewerOptions[0]!, reviewerOptions[2]!] +const defaultTag = tagOptions[2]! + +const renderOptionItem = (option: Option, index?: number) => ( + + + {option.icon && } + + {option.label} + {option.meta && {option.meta}} + + + + +) + +const renderSimpleOptionItem = (option: Option, index?: number) => ( + + {option.label} + + +) + +const PopupSearchInput = ({ + label, + placeholder, +}: { + label: string + placeholder: string +}) => ( + + + + + +) + +const GroupedToolList = () => { + const groups = useComboboxFilteredItems() + + return ( + + {groups.map((group, groupIndex) => ( + + {groupIndex > 0 && } + {group.label} + + {(option: Option) => renderOptionItem(option)} + + + ))} + + ) +} + +const VirtualizedModelList = ({ + virtualizerRef, +}: { + virtualizerRef: RefObject +}) => { + const scrollRef = useRef(null) + const filteredItems = useComboboxFilteredItems