From 7ecbed3b0496087b85937545cf5119b65032ab16 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Fri, 22 May 2026 04:25:03 +0900 Subject: [PATCH] chore: add Type to test (#36454) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../controllers/console/auth/test_oauth.py | 37 ++++++----- .../console/workspace/test_tool_provider.py | 5 +- .../core/datasource/test_notion_provider.py | 66 ++++++++++--------- .../core/moderation/test_output_moderation.py | 34 +++++----- .../core/ops/test_ops_trace_manager.py | 6 +- .../unit_tests/services/controller_api.py | 15 +++-- .../unit_tests/services/test_file_service.py | 44 ++++++------- .../tasks/test_clean_dataset_task.py | 42 ++++++------ .../tasks/test_document_indexing_sync_task.py | 22 +++---- 9 files changed, 144 insertions(+), 127 deletions(-) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index a5ae83739c..3c496d1fc8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -45,7 +45,7 @@ class TestGetOAuthProviders: ) @patch("controllers.console.auth.oauth.dify_config") def test_should_configure_oauth_providers_correctly( - self, mock_config, app, github_config, google_config, expected_github, expected_google + self, mock_config, app: Flask, github_config, google_config, expected_github, expected_google ): mock_config.GITHUB_CLIENT_ID = github_config["id"] mock_config.GITHUB_CLIENT_SECRET = github_config["secret"] @@ -89,7 +89,7 @@ class TestOAuthLogin: self, mock_redirect, mock_get_providers, - resource, + resource: OAuthLogin, app: Flask, mock_oauth_provider, invite_token, @@ -114,7 +114,7 @@ class TestOAuthLogin: self, mock_redirect, mock_get_providers, - resource, + resource: OAuthLogin, app: Flask, mock_oauth_provider, ): @@ -136,7 +136,7 @@ class TestOAuthLogin: self, mock_redirect, mock_get_providers, - resource, + resource: OAuthLogin, app: Flask, mock_oauth_provider, ): @@ -212,7 +212,7 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -237,7 +237,9 @@ class TestOAuthCallback: ], ) @patch("controllers.console.auth.oauth.get_oauth_providers") - def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error): + def test_should_handle_oauth_exceptions( + self, mock_get_providers, resource: OAuthCallback, app: Flask, exception, expected_error + ): # Import the real requests module to create a proper exception import httpx @@ -265,7 +267,7 @@ class TestOAuthCallback: mock_register_service, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -310,7 +312,7 @@ class TestOAuthCallback: mock_config, mock_tenant_service, mock_account_service, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, account_status, @@ -349,7 +351,7 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -385,7 +387,7 @@ class TestOAuthCallback: mock_generate_account, mock_get_providers, mock_config, - resource, + resource: OAuthCallback, app: Flask, oauth_setup, ): @@ -460,7 +462,12 @@ class TestAccountGeneration: @patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback") @patch("controllers.console.auth.oauth.Account") def test_should_get_account_by_openid_or_email( - self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account + self, + mock_account_model, + mock_get_account, + flask_req_ctx_with_containers, + user_info: OAuthUserInfo, + mock_account, ): # Test OpenID found mock_account_model.get_by_openid.return_value = mock_account @@ -516,7 +523,7 @@ class TestAccountGeneration: mock_feature_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, mock_account, allow_register, existing_account, @@ -592,7 +599,7 @@ class TestAccountGeneration: mock_feature_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, ): mock_feature_service.get_system_features.return_value.is_allow_register = True mock_register_service.register.return_value = MagicMock() @@ -623,7 +630,7 @@ class TestAccountGeneration: mock_feature_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, ): mock_feature_service.get_system_features.return_value.is_allow_register = True mock_register_service.register.return_value = MagicMock() @@ -654,7 +661,7 @@ class TestAccountGeneration: mock_tenant_service, mock_get_account, app: Flask, - user_info, + user_info: OAuthUserInfo, mock_account, ): mock_get_account.return_value = mock_account diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index d944613886..b977a3eb7a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from flask.testing import FlaskClient from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -73,7 +74,9 @@ def client(flask_app_with_containers: Flask): @patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True) @patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True) @pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant") -def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client): +def test_create_mcp_provider_populates_tools( + mock_reconnect, mock_session, mock_current_account_with_tenant, client: FlaskClient +): # Arrange: reconnect returns tools immediately mock_reconnect.return_value = ReconnectResult( authed=True, diff --git a/api/tests/unit_tests/core/datasource/test_notion_provider.py b/api/tests/unit_tests/core/datasource/test_notion_provider.py index d21b9e471b..ecbd9691e9 100644 --- a/api/tests/unit_tests/core/datasource/test_notion_provider.py +++ b/api/tests/unit_tests/core/datasource/test_notion_provider.py @@ -183,7 +183,7 @@ class TestNotionExtractorPageRetrieval: } @patch("httpx.request") - def test_get_notion_block_data_simple_page(self, mock_request, extractor): + def test_get_notion_block_data_simple_page(self, mock_request, extractor: NotionExtractor): """Test retrieving simple page with basic blocks.""" # Arrange mock_data = { @@ -207,7 +207,7 @@ class TestNotionExtractorPageRetrieval: mock_request.assert_called_once() @patch("httpx.request") - def test_get_notion_block_data_with_headings(self, mock_request, extractor): + def test_get_notion_block_data_with_headings(self, mock_request, extractor: NotionExtractor): """Test retrieving page with heading blocks.""" # Arrange mock_data = { @@ -234,7 +234,7 @@ class TestNotionExtractorPageRetrieval: assert "### Sub-subtitle" in result[3] @patch("httpx.request") - def test_get_notion_block_data_with_pagination(self, mock_request, extractor): + def test_get_notion_block_data_with_pagination(self, mock_request, extractor: NotionExtractor): """Test retrieving page with paginated results.""" # Arrange first_page = { @@ -264,7 +264,7 @@ class TestNotionExtractorPageRetrieval: assert mock_request.call_count == 2 @patch("httpx.request") - def test_get_notion_block_data_with_nested_blocks(self, mock_request, extractor): + def test_get_notion_block_data_with_nested_blocks(self, mock_request, extractor: NotionExtractor): """Test retrieving page with nested block structure.""" # Arrange # First call returns parent blocks @@ -300,7 +300,7 @@ class TestNotionExtractorPageRetrieval: assert mock_request.call_count == 2 @patch("httpx.request") - def test_get_notion_block_data_error_handling(self, mock_request, extractor): + def test_get_notion_block_data_error_handling(self, mock_request, extractor: NotionExtractor): """Test error handling for failed API requests.""" # Arrange mock_request.return_value = self._create_mock_response({}, status_code=404) @@ -311,7 +311,7 @@ class TestNotionExtractorPageRetrieval: assert "Error fetching Notion block data" in str(exc_info.value) @patch("httpx.request") - def test_get_notion_block_data_invalid_response(self, mock_request, extractor): + def test_get_notion_block_data_invalid_response(self, mock_request, extractor: NotionExtractor): """Test handling of invalid API response structure.""" # Arrange mock_request.return_value = self._create_mock_response({"invalid": "structure"}) @@ -322,7 +322,7 @@ class TestNotionExtractorPageRetrieval: assert "Error fetching Notion block data" in str(exc_info.value) @patch("httpx.request") - def test_get_notion_block_data_http_error(self, mock_request, extractor): + def test_get_notion_block_data_http_error(self, mock_request, extractor: NotionExtractor): """Test handling of HTTP errors during request.""" # Arrange mock_request.side_effect = httpx.HTTPError("Network error") @@ -368,7 +368,7 @@ class TestNotionExtractorDatabaseRetrieval: } @patch("httpx.post") - def test_get_notion_database_data_simple(self, mock_post, extractor): + def test_get_notion_database_data_simple(self, mock_post, extractor: NotionExtractor): """Test retrieving simple database with basic properties.""" # Arrange mock_response = Mock() @@ -407,7 +407,7 @@ class TestNotionExtractorDatabaseRetrieval: assert "Status:Done" in content @patch("httpx.post") - def test_get_notion_database_data_with_pagination(self, mock_post, extractor): + def test_get_notion_database_data_with_pagination(self, mock_post, extractor: NotionExtractor): """Test retrieving database with paginated results.""" # Arrange first_response = Mock() @@ -441,7 +441,7 @@ class TestNotionExtractorDatabaseRetrieval: assert mock_post.call_count == 2 @patch("httpx.post") - def test_get_notion_database_data_multi_select(self, mock_post, extractor): + def test_get_notion_database_data_multi_select(self, mock_post, extractor: NotionExtractor): """Test database with multi_select property type.""" # Arrange mock_response = Mock() @@ -474,7 +474,7 @@ class TestNotionExtractorDatabaseRetrieval: assert "Tags:" in content @patch("httpx.post") - def test_get_notion_database_data_empty_properties(self, mock_post, extractor): + def test_get_notion_database_data_empty_properties(self, mock_post, extractor: NotionExtractor): """Test database with empty property values.""" # Arrange mock_response = Mock() @@ -504,7 +504,7 @@ class TestNotionExtractorDatabaseRetrieval: assert "Row Page URL:" in content @patch("httpx.post") - def test_get_notion_database_data_empty_results(self, mock_post, extractor): + def test_get_notion_database_data_empty_results(self, mock_post, extractor: NotionExtractor): """Test handling of empty database.""" # Arrange mock_response = Mock() @@ -523,7 +523,7 @@ class TestNotionExtractorDatabaseRetrieval: assert len(result) == 0 @patch("httpx.post") - def test_get_notion_database_data_missing_results(self, mock_post, extractor): + def test_get_notion_database_data_missing_results(self, mock_post, extractor: NotionExtractor): """Test handling of malformed API response.""" # Arrange mock_response = Mock() @@ -559,7 +559,7 @@ class TestNotionExtractorTableParsing: ) @patch("httpx.request") - def test_read_table_rows_simple(self, mock_request, extractor): + def test_read_table_rows_simple(self, mock_request, extractor: NotionExtractor): """Test reading simple table with headers and rows.""" # Arrange mock_data = { @@ -611,7 +611,7 @@ class TestNotionExtractorTableParsing: assert "| Bob | 25 |" in result @patch("httpx.request") - def test_read_table_rows_with_empty_cells(self, mock_request, extractor): + def test_read_table_rows_with_empty_cells(self, mock_request, extractor: NotionExtractor): """Test reading table with empty cells.""" # Arrange mock_data = { @@ -643,7 +643,7 @@ class TestNotionExtractorTableParsing: assert "Value1" in result @patch("httpx.request") - def test_read_table_rows_with_pagination(self, mock_request, extractor): + def test_read_table_rows_with_pagination(self, mock_request, extractor: NotionExtractor): """Test reading table with paginated results.""" # Arrange first_page = { @@ -960,7 +960,7 @@ class TestNotionExtractorReadBlock: ) @patch("httpx.request") - def test_read_block_with_indentation(self, mock_request, extractor): + def test_read_block_with_indentation(self, mock_request, extractor: NotionExtractor): """Test reading nested blocks with proper indentation.""" # Arrange mock_data = { @@ -990,7 +990,7 @@ class TestNotionExtractorReadBlock: assert "\t\tNested content" in result @patch("httpx.request") - def test_read_block_skip_child_page(self, mock_request, extractor): + def test_read_block_skip_child_page(self, mock_request, extractor: NotionExtractor): """Test that child_page blocks don't recurse.""" # Arrange mock_data = { @@ -1139,7 +1139,7 @@ class TestNotionExtractorAdvancedBlockTypes: } @patch("httpx.request") - def test_get_notion_block_data_with_list_blocks(self, mock_request, extractor): + def test_get_notion_block_data_with_list_blocks(self, mock_request, extractor: NotionExtractor): """Test retrieving page with bulleted and numbered list items. Both list types should be extracted with their content. @@ -1165,7 +1165,7 @@ class TestNotionExtractorAdvancedBlockTypes: assert "Numbered item" in result[1] @patch("httpx.request") - def test_get_notion_block_data_with_special_blocks(self, mock_request, extractor): + def test_get_notion_block_data_with_special_blocks(self, mock_request, extractor: NotionExtractor): """Test retrieving page with code, quote, and callout blocks. Special block types should preserve their content correctly. @@ -1193,7 +1193,7 @@ class TestNotionExtractorAdvancedBlockTypes: assert "Important note" in result[2] @patch("httpx.request") - def test_get_notion_block_data_with_toggle_block(self, mock_request, extractor): + def test_get_notion_block_data_with_toggle_block(self, mock_request, extractor: NotionExtractor): """Test retrieving page with toggle block containing children. Toggle blocks can have nested content that should be extracted. @@ -1229,7 +1229,7 @@ class TestNotionExtractorAdvancedBlockTypes: assert "Hidden content" in result[0] @patch("httpx.request") - def test_get_notion_block_data_mixed_block_types(self, mock_request, extractor): + def test_get_notion_block_data_mixed_block_types(self, mock_request, extractor: NotionExtractor): """Test retrieving page with mixed block types. Real Notion pages contain various block types mixed together. @@ -1308,7 +1308,7 @@ class TestNotionExtractorDatabaseAdvanced: } @patch("httpx.post") - def test_get_notion_database_data_with_various_property_types(self, mock_post, extractor): + def test_get_notion_database_data_with_various_property_types(self, mock_post, extractor: NotionExtractor): """Test database with multiple property types. Tests date, number, checkbox, URL, email, phone, and status properties. @@ -1354,7 +1354,7 @@ class TestNotionExtractorDatabaseAdvanced: assert "Status:Active" in content @patch("httpx.post") - def test_get_notion_database_data_large_pagination(self, mock_post, extractor): + def test_get_notion_database_data_large_pagination(self, mock_post, extractor: NotionExtractor): """Test database with multiple pages of results. Large databases require multiple API calls with cursor-based pagination. @@ -1415,7 +1415,7 @@ class TestNotionExtractorDatabaseAdvanced: assert mock_post.call_count == 3 @patch("httpx.post") - def test_get_notion_database_data_with_rich_text_property(self, mock_post, extractor): + def test_get_notion_database_data_with_rich_text_property(self, mock_post, extractor: NotionExtractor): """Test database with rich_text property type. Rich text properties can contain formatted text and should be extracted. @@ -1486,7 +1486,9 @@ class TestNotionExtractorErrorScenarios: ], ) @patch("httpx.request") - def test_get_notion_block_data_network_errors(self, mock_request, extractor, error_type, error_value): + def test_get_notion_block_data_network_errors( + self, mock_request, extractor: NotionExtractor, error_type, error_value + ): """Test handling of various network errors. Network issues (timeouts, connection failures) should raise appropriate errors. @@ -1509,7 +1511,9 @@ class TestNotionExtractorErrorScenarios: ], ) @patch("httpx.request") - def test_get_notion_block_data_http_status_errors(self, mock_request, extractor, status_code, description): + def test_get_notion_block_data_http_status_errors( + self, mock_request, extractor: NotionExtractor, status_code, description + ): """Test handling of various HTTP status errors. Different HTTP error codes (401, 403, 404, 429) should be handled appropriately. @@ -1534,7 +1538,9 @@ class TestNotionExtractorErrorScenarios: ], ) @patch("httpx.request") - def test_get_notion_block_data_malformed_responses(self, mock_request, extractor, response_data, description): + def test_get_notion_block_data_malformed_responses( + self, mock_request, extractor: NotionExtractor, response_data, description + ): """Test handling of malformed API responses. Various malformed responses should be handled gracefully. @@ -1551,7 +1557,7 @@ class TestNotionExtractorErrorScenarios: assert "Error fetching Notion block data" in str(exc_info.value) @patch("httpx.post") - def test_get_notion_database_data_with_query_filter(self, mock_post, extractor): + def test_get_notion_database_data_with_query_filter(self, mock_post, extractor: NotionExtractor): """Test database query with custom filter. Databases can be queried with filters to retrieve specific rows. @@ -1618,7 +1624,7 @@ class TestNotionExtractorTableAdvanced: ) @patch("httpx.request") - def test_read_table_rows_with_many_columns(self, mock_request, extractor): + def test_read_table_rows_with_many_columns(self, mock_request, extractor: NotionExtractor): """Test reading table with many columns. Tables can have numerous columns; all should be extracted correctly. diff --git a/api/tests/unit_tests/core/moderation/test_output_moderation.py b/api/tests/unit_tests/core/moderation/test_output_moderation.py index 36a80cc76c..ce384c4c13 100644 --- a/api/tests/unit_tests/core/moderation/test_output_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_output_moderation.py @@ -19,22 +19,22 @@ class TestOutputModeration: return ModerationRule(type="keywords", config={"keywords": "badword"}) @pytest.fixture - def output_moderation(self, mock_queue_manager, moderation_rule): + def output_moderation(self, mock_queue_manager, moderation_rule: ModerationRule): return OutputModeration( tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager ) - def test_should_direct_output(self, output_moderation): + def test_should_direct_output(self, output_moderation: OutputModeration): assert output_moderation.should_direct_output() is False output_moderation.final_output = "blocked" assert output_moderation.should_direct_output() is True - def test_get_final_output(self, output_moderation): + def test_get_final_output(self, output_moderation: OutputModeration): assert output_moderation.get_final_output() == "" output_moderation.final_output = "blocked" assert output_moderation.get_final_output() == "blocked" - def test_append_new_token(self, output_moderation): + def test_append_new_token(self, output_moderation: OutputModeration): with patch.object(OutputModeration, "start_thread") as mock_start: output_moderation.append_new_token("hello") assert output_moderation.buffer == "hello" @@ -45,7 +45,7 @@ class TestOutputModeration: assert output_moderation.buffer == "hello world" assert mock_start.call_count == 1 - def test_moderation_completion_no_flag(self, output_moderation): + def test_moderation_completion_no_flag(self, output_moderation: OutputModeration): with patch.object(OutputModeration, "moderation") as mock_moderation: mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) @@ -55,7 +55,7 @@ class TestOutputModeration: assert flagged is False assert output_moderation.is_final_chunk is True - def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager): + def test_moderation_completion_flagged_direct_output(self, output_moderation: OutputModeration, mock_queue_manager): with patch.object(OutputModeration, "moderation") as mock_moderation: mock_moderation.return_value = ModerationOutputsResult( flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" @@ -71,7 +71,7 @@ class TestOutputModeration: assert args[0].text == "preset" assert args[1] == PublishFrom.TASK_PIPELINE - def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager): + def test_moderation_completion_flagged_overridden(self, output_moderation: OutputModeration, mock_queue_manager): with patch.object(OutputModeration, "moderation") as mock_moderation: mock_moderation.return_value = ModerationOutputsResult( flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content" @@ -85,7 +85,7 @@ class TestOutputModeration: args, _ = mock_queue_manager.publish.call_args assert args[0].text == "masked content" - def test_start_thread(self, output_moderation): + def test_start_thread(self, output_moderation: OutputModeration): mock_app = MagicMock(spec=Flask) with patch("core.moderation.output_moderation.current_app") as mock_current_app: mock_current_app._get_current_object = MagicMock(return_value=mock_app) @@ -99,7 +99,7 @@ class TestOutputModeration: mock_thread_class.assert_called_once() mock_thread_instance.start.assert_called_once() - def test_stop_thread(self, output_moderation): + def test_stop_thread(self, output_moderation: OutputModeration): mock_thread = MagicMock() mock_thread.is_alive.return_value = True output_moderation.thread = mock_thread @@ -113,7 +113,7 @@ class TestOutputModeration: assert output_moderation.thread_running is True @patch("core.moderation.output_moderation.ModerationFactory") - def test_moderation_success(self, mock_factory_class, output_moderation): + def test_moderation_success(self, mock_factory_class, output_moderation: OutputModeration): mock_factory = mock_factory_class.return_value mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) mock_factory.moderation_for_outputs.return_value = mock_result @@ -126,13 +126,13 @@ class TestOutputModeration: ) @patch("core.moderation.output_moderation.ModerationFactory") - def test_moderation_exception(self, mock_factory_class, output_moderation): + def test_moderation_exception(self, mock_factory_class, output_moderation: OutputModeration): mock_factory_class.side_effect = Exception("error") result = output_moderation.moderation("tenant", "app", "buffer") assert result is None - def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager): + def test_worker_loop_and_exit(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) # Test exit on thread_running=False @@ -140,7 +140,7 @@ class TestOutputModeration: output_moderation.worker(mock_app, 10) # Should exit immediately - def test_worker_no_flag(self, output_moderation): + def test_worker_no_flag(self, output_moderation: OutputModeration): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: @@ -160,7 +160,7 @@ class TestOutputModeration: assert mock_moderation.called - def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager): + def test_worker_flagged_direct_output(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: @@ -177,7 +177,7 @@ class TestOutputModeration: mock_queue_manager.publish.assert_called_once() # It breaks on DIRECT_OUTPUT - def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager): + def test_worker_flagged_overridden(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: @@ -199,7 +199,7 @@ class TestOutputModeration: args, _ = mock_queue_manager.publish.call_args assert args[0].text == "masked" - def test_worker_chunk_too_small(self, output_moderation): + def test_worker_chunk_too_small(self, output_moderation: OutputModeration): mock_app = MagicMock(spec=Flask) with patch("time.sleep") as mock_sleep: # chunk_length < buffer_size and not is_final_chunk @@ -215,7 +215,7 @@ class TestOutputModeration: mock_sleep.assert_called_once_with(1) - def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager): + def test_worker_empty_not_flagged(self, output_moderation: OutputModeration, mock_queue_manager): mock_app = MagicMock(spec=Flask) with patch.object(OutputModeration, "moderation") as mock_moderation: # Return None (exception or no rule) diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 33a3293682..704f5d362c 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -262,7 +262,7 @@ def workflow_repo_fixture(monkeypatch: pytest.MonkeyPatch): @pytest.fixture -def trace_task_message(monkeypatch, mock_db): +def trace_task_message(monkeypatch: pytest.MonkeyPatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) @@ -353,7 +353,7 @@ def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch: pytest.Mo assert OpsTraceManager.get_ops_trace_instance("app-id") is None -def test_get_ops_trace_instance_success(monkeypatch, mock_db): +def test_get_ops_trace_instance_success(monkeypatch: pytest.MonkeyPatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) mock_db.get.return_value = app monkeypatch.setattr( @@ -497,7 +497,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): assert result.documents == [{"doc": "value"}] -def test_trace_task_tool_trace(monkeypatch, mock_db): +def test_trace_task_tool_trace(monkeypatch: pytest.MonkeyPatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) configure_db_scalar(mock_db, message_file=FakeMessageFile()) diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py index e7f7cabecd..ea60b94b61 100644 --- a/api/tests/unit_tests/services/controller_api.py +++ b/api/tests/unit_tests/services/controller_api.py @@ -87,6 +87,7 @@ from uuid import uuid4 import pytest from flask import Flask +from flask.testing import FlaskClient from flask_restx import Api from controllers.console.datasets.datasets import DatasetApi, DatasetListApi @@ -339,7 +340,7 @@ class TestDatasetListApi: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_get_datasets_success(self, client, mock_current_user): + def test_get_datasets_success(self, client: FlaskClient, mock_current_user): """ Test successful retrieval of dataset list. @@ -380,7 +381,7 @@ class TestDatasetListApi: # Verify service was called mock_get_datasets.assert_called_once() - def test_get_datasets_with_search(self, client, mock_current_user): + def test_get_datasets_with_search(self, client: FlaskClient, mock_current_user): """ Test dataset listing with search keyword. @@ -410,7 +411,7 @@ class TestDatasetListApi: call_args = mock_get_datasets.call_args assert call_args[1]["search"] == search_keyword - def test_get_datasets_with_pagination(self, client, mock_current_user): + def test_get_datasets_with_pagination(self, client: FlaskClient, mock_current_user): """ Test dataset listing with pagination parameters. @@ -495,7 +496,7 @@ class TestDatasetApiGet: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_get_dataset_success(self, client, mock_current_user): + def test_get_dataset_success(self, client: FlaskClient, mock_current_user): """ Test successful retrieval of a single dataset. @@ -533,7 +534,7 @@ class TestDatasetApiGet: mock_get_dataset.assert_called_once_with(dataset_id) mock_check_perm.assert_called_once() - def test_get_dataset_not_found(self, client, mock_current_user): + def test_get_dataset_not_found(self, client: FlaskClient, mock_current_user): """ Test error handling when dataset is not found. @@ -611,7 +612,7 @@ class TestDatasetApiCreate: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_create_dataset_success(self, client, mock_current_user): + def test_create_dataset_success(self, client: FlaskClient, mock_current_user): """ Test successful creation of a dataset. @@ -706,7 +707,7 @@ class TestHitTestingApi: mock_get_user.return_value = (mock_user, mock_tenant_id) yield mock_get_user - def test_hit_testing_success(self, client, mock_current_user): + def test_hit_testing_success(self, client: FlaskClient, mock_current_user): """ Test successful hit testing operation. diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py index 69bd194a68..2e6ca7dbb9 100644 --- a/api/tests/unit_tests/services/test_file_service.py +++ b/api/tests/unit_tests/services/test_file_service.py @@ -52,7 +52,7 @@ class TestFileService: @patch("services.file_service.extract_tenant_id") @patch("services.file_service.file_helpers.get_signed_file_url") def test_upload_file_success( - self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session + self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service: FileService, mock_db_session ): # Setup mock_tenant_id.return_value = "tenant_id" @@ -88,7 +88,7 @@ class TestFileService: with pytest.raises(ValueError, match="Filename contains invalid characters"): file_service.upload_file(filename="invalid/file.txt", content=b"", mimetype="text/plain", user=MagicMock()) - def test_upload_file_long_filename(self, file_service, mock_db_session): + def test_upload_file_long_filename(self, file_service: FileService, mock_db_session): # Setup long_name = "a" * 210 + ".txt" user = MagicMock(spec=Account) @@ -124,7 +124,7 @@ class TestFileService: with pytest.raises(FileTooLargeError): file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock()) - def test_upload_file_end_user(self, file_service, mock_db_session): + def test_upload_file_end_user(self, file_service: FileService, mock_db_session): user = MagicMock(spec=EndUser) user.id = "end_user_id" @@ -160,7 +160,7 @@ class TestFileService: assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False - def test_get_file_base64_success(self, file_service, mock_db_session): + def test_get_file_base64_success(self, file_service: FileService, mock_db_session): # Setup upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" @@ -177,12 +177,12 @@ class TestFileService: assert result == base64.b64encode(b"test content").decode() mock_storage.load_once.assert_called_once_with("test_key") - def test_get_file_base64_not_found(self, file_service, mock_db_session): + def test_get_file_base64_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_base64("non_existent") - def test_upload_text_success(self, file_service, mock_db_session): + def test_upload_text_success(self, file_service: FileService, mock_db_session): # Setup text = "sample text" text_name = "test.txt" @@ -204,13 +204,13 @@ class TestFileService: mock_db_session.add.assert_called_once() mock_db_session.commit.assert_called_once() - def test_upload_text_long_name(self, file_service, mock_db_session): + def test_upload_text_long_name(self, file_service: FileService, mock_db_session): long_name = "a" * 210 with patch("services.file_service.storage"): result = file_service.upload_text("text", long_name, "user", "tenant") assert len(result.name) == 200 - def test_get_file_preview_success(self, file_service, mock_db_session): + def test_get_file_preview_success(self, file_service: FileService, mock_db_session): # Setup upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" @@ -226,12 +226,12 @@ class TestFileService: # Assert assert result == "Extracted text content" - def test_get_file_preview_not_found(self, file_service, mock_db_session): + def test_get_file_preview_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_preview("non_existent", "tenant_id") - def test_get_file_preview_unsupported_type(self, file_service, mock_db_session): + def test_get_file_preview_unsupported_type(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "exe" @@ -239,7 +239,7 @@ class TestFileService: with pytest.raises(UnsupportedFileTypeError): file_service.get_file_preview("file_id", "tenant_id") - def test_get_image_preview_success(self, file_service, mock_db_session): + def test_get_image_preview_success(self, file_service: FileService, mock_db_session): # Setup upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" @@ -268,14 +268,14 @@ class TestFileService: with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_image_preview("file_id", "ts", "nonce", "sign") - def test_get_image_preview_not_found(self, file_service, mock_db_session): + def test_get_image_preview_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_image_preview("file_id", "ts", "nonce", "sign") - def test_get_image_preview_unsupported_type(self, file_service, mock_db_session): + def test_get_image_preview_unsupported_type(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" @@ -285,7 +285,7 @@ class TestFileService: with pytest.raises(UnsupportedFileTypeError): file_service.get_image_preview("file_id", "ts", "nonce", "sign") - def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session): + def test_get_file_generator_by_file_id_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" @@ -308,14 +308,14 @@ class TestFileService: with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") - def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): + def test_get_file_generator_by_file_id_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: mock_verify.return_value = True with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") - def test_get_public_image_preview_success(self, file_service, mock_db_session): + def test_get_public_image_preview_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "png" @@ -329,12 +329,12 @@ class TestFileService: assert gen == b"image content" assert mime == "image/png" - def test_get_public_image_preview_not_found(self, file_service, mock_db_session): + def test_get_public_image_preview_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found or signature is invalid"): file_service.get_public_image_preview("file_id") - def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session): + def test_get_public_image_preview_unsupported_type(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.extension = "txt" @@ -342,7 +342,7 @@ class TestFileService: with pytest.raises(UnsupportedFileTypeError): file_service.get_public_image_preview("file_id") - def test_get_file_content_success(self, file_service, mock_db_session): + def test_get_file_content_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" @@ -353,12 +353,12 @@ class TestFileService: result = file_service.get_file_content("file_id") assert result == "hello world" - def test_get_file_content_not_found(self, file_service, mock_db_session): + def test_get_file_content_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None with pytest.raises(NotFound, match="File not found"): file_service.get_file_content("file_id") - def test_delete_file_success(self, file_service, mock_db_session): + def test_delete_file_success(self, file_service: FileService, mock_db_session): upload_file = MagicMock(spec=UploadFile) upload_file.id = "file_id" upload_file.key = "key" @@ -370,7 +370,7 @@ class TestFileService: mock_storage.delete.assert_called_once_with("key") mock_db_session.delete.assert_called_once_with(upload_file) - def test_delete_file_not_found(self, file_service, mock_db_session): + def test_delete_file_not_found(self, file_service: FileService, mock_db_session): mock_db_session.scalar.return_value = None file_service.delete_file("file_id") # Should return without doing anything diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index b4332334ab..7ce897eb02 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -151,9 +151,9 @@ class TestErrorHandling: def test_clean_dataset_task_rollback_failure_still_closes_session( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -198,9 +198,9 @@ class TestPipelineAndWorkflowDeletion: def test_clean_dataset_task_with_pipeline_id( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, pipeline_id, mock_db_session, mock_storage, @@ -231,9 +231,9 @@ class TestPipelineAndWorkflowDeletion: def test_clean_dataset_task_without_pipeline_id( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -271,9 +271,9 @@ class TestSegmentAttachmentCleanup: def test_clean_dataset_task_with_attachments( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -321,9 +321,9 @@ class TestSegmentAttachmentCleanup: def test_clean_dataset_task_attachment_storage_failure( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -375,9 +375,9 @@ class TestEdgeCases: def test_clean_dataset_task_session_always_closed( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, @@ -413,9 +413,9 @@ class TestIndexProcessorParameters: def test_clean_dataset_task_passes_correct_parameters_to_index_processor( self, - dataset_id, - tenant_id, - collection_binding_id, + dataset_id: str, + tenant_id: str, + collection_binding_id: str, mock_db_session, mock_storage, mock_index_processor_factory, diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index 41d3068a10..e5782899e3 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -124,10 +124,10 @@ class TestDocumentIndexingSyncTaskCollaboratorParams: mock_datasource_provider_service, mock_notion_extractor, mock_document, - dataset_id, - document_id, - notion_workspace_id, - notion_page_id, + dataset_id: str, + document_id: str, + notion_workspace_id: str, + notion_page_id: str, ): """Test that NotionExtractor is initialized with expected arguments.""" # Arrange @@ -151,9 +151,9 @@ class TestDocumentIndexingSyncTaskCollaboratorParams: mock_datasource_provider_service, mock_notion_extractor, mock_document, - dataset_id, - document_id, - credential_id, + dataset_id: str, + document_id: str, + credential_id: str, ): """Test that datasource credentials are requested with expected identifiers.""" # Arrange @@ -176,8 +176,8 @@ class TestDocumentIndexingSyncTaskCollaboratorParams: mock_datasource_provider_service, mock_notion_extractor, mock_document, - dataset_id, - document_id, + dataset_id: str, + document_id: str, ): """Test that missing credential_id is forwarded as None.""" # Arrange @@ -212,8 +212,8 @@ class TestDataSourceInfoSerialization: self, mock_document, mock_dataset, - dataset_id, - document_id, + dataset_id: str, + document_id: str, ): """data_source_info must be serialized with json.dumps before DB write.""" with (