chore: add Type to test (#36454)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-05-22 04:25:03 +09:00 committed by GitHub
parent 5b58defd62
commit 7ecbed3b04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 144 additions and 127 deletions

View File

@ -45,7 +45,7 @@ class TestGetOAuthProviders:
)
@patch("controllers.console.auth.oauth.dify_config")
def test_should_configure_oauth_providers_correctly(
self, mock_config, app, github_config, google_config, expected_github, expected_google
self, mock_config, app: Flask, github_config, google_config, expected_github, expected_google
):
mock_config.GITHUB_CLIENT_ID = github_config["id"]
mock_config.GITHUB_CLIENT_SECRET = github_config["secret"]
@ -89,7 +89,7 @@ class TestOAuthLogin:
self,
mock_redirect,
mock_get_providers,
resource,
resource: OAuthLogin,
app: Flask,
mock_oauth_provider,
invite_token,
@ -114,7 +114,7 @@ class TestOAuthLogin:
self,
mock_redirect,
mock_get_providers,
resource,
resource: OAuthLogin,
app: Flask,
mock_oauth_provider,
):
@ -136,7 +136,7 @@ class TestOAuthLogin:
self,
mock_redirect,
mock_get_providers,
resource,
resource: OAuthLogin,
app: Flask,
mock_oauth_provider,
):
@ -212,7 +212,7 @@ class TestOAuthCallback:
mock_generate_account,
mock_get_providers,
mock_config,
resource,
resource: OAuthCallback,
app: Flask,
oauth_setup,
):
@ -237,7 +237,9 @@ class TestOAuthCallback:
],
)
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_should_handle_oauth_exceptions(self, mock_get_providers, resource, app, exception, expected_error):
def test_should_handle_oauth_exceptions(
self, mock_get_providers, resource: OAuthCallback, app: Flask, exception, expected_error
):
# Import the real requests module to create a proper exception
import httpx
@ -265,7 +267,7 @@ class TestOAuthCallback:
mock_register_service,
mock_get_providers,
mock_config,
resource,
resource: OAuthCallback,
app: Flask,
oauth_setup,
):
@ -310,7 +312,7 @@ class TestOAuthCallback:
mock_config,
mock_tenant_service,
mock_account_service,
resource,
resource: OAuthCallback,
app: Flask,
oauth_setup,
account_status,
@ -349,7 +351,7 @@ class TestOAuthCallback:
mock_generate_account,
mock_get_providers,
mock_config,
resource,
resource: OAuthCallback,
app: Flask,
oauth_setup,
):
@ -385,7 +387,7 @@ class TestOAuthCallback:
mock_generate_account,
mock_get_providers,
mock_config,
resource,
resource: OAuthCallback,
app: Flask,
oauth_setup,
):
@ -460,7 +462,12 @@ class TestAccountGeneration:
@patch("controllers.console.auth.oauth.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.oauth.Account")
def test_should_get_account_by_openid_or_email(
self, mock_account_model, mock_get_account, flask_req_ctx_with_containers, user_info, mock_account
self,
mock_account_model,
mock_get_account,
flask_req_ctx_with_containers,
user_info: OAuthUserInfo,
mock_account,
):
# Test OpenID found
mock_account_model.get_by_openid.return_value = mock_account
@ -516,7 +523,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
user_info: OAuthUserInfo,
mock_account,
allow_register,
existing_account,
@ -592,7 +599,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
user_info: OAuthUserInfo,
):
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
@ -623,7 +630,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
user_info: OAuthUserInfo,
):
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
@ -654,7 +661,7 @@ class TestAccountGeneration:
mock_tenant_service,
mock_get_account,
app: Flask,
user_info,
user_info: OAuthUserInfo,
mock_account,
):
mock_get_account.return_value = mock_account

View File

@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask.testing import FlaskClient
from werkzeug.exceptions import Forbidden
from controllers.console.workspace.tool_providers import (
@ -73,7 +74,9 @@ def client(flask_app_with_containers: Flask):
@patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True)
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
def test_create_mcp_provider_populates_tools(
mock_reconnect, mock_session, mock_current_account_with_tenant, client: FlaskClient
):
# Arrange: reconnect returns tools immediately
mock_reconnect.return_value = ReconnectResult(
authed=True,

View File

@ -183,7 +183,7 @@ class TestNotionExtractorPageRetrieval:
}
@patch("httpx.request")
def test_get_notion_block_data_simple_page(self, mock_request, extractor):
def test_get_notion_block_data_simple_page(self, mock_request, extractor: NotionExtractor):
"""Test retrieving simple page with basic blocks."""
# Arrange
mock_data = {
@ -207,7 +207,7 @@ class TestNotionExtractorPageRetrieval:
mock_request.assert_called_once()
@patch("httpx.request")
def test_get_notion_block_data_with_headings(self, mock_request, extractor):
def test_get_notion_block_data_with_headings(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with heading blocks."""
# Arrange
mock_data = {
@ -234,7 +234,7 @@ class TestNotionExtractorPageRetrieval:
assert "### Sub-subtitle" in result[3]
@patch("httpx.request")
def test_get_notion_block_data_with_pagination(self, mock_request, extractor):
def test_get_notion_block_data_with_pagination(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with paginated results."""
# Arrange
first_page = {
@ -264,7 +264,7 @@ class TestNotionExtractorPageRetrieval:
assert mock_request.call_count == 2
@patch("httpx.request")
def test_get_notion_block_data_with_nested_blocks(self, mock_request, extractor):
def test_get_notion_block_data_with_nested_blocks(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with nested block structure."""
# Arrange
# First call returns parent blocks
@ -300,7 +300,7 @@ class TestNotionExtractorPageRetrieval:
assert mock_request.call_count == 2
@patch("httpx.request")
def test_get_notion_block_data_error_handling(self, mock_request, extractor):
def test_get_notion_block_data_error_handling(self, mock_request, extractor: NotionExtractor):
"""Test error handling for failed API requests."""
# Arrange
mock_request.return_value = self._create_mock_response({}, status_code=404)
@ -311,7 +311,7 @@ class TestNotionExtractorPageRetrieval:
assert "Error fetching Notion block data" in str(exc_info.value)
@patch("httpx.request")
def test_get_notion_block_data_invalid_response(self, mock_request, extractor):
def test_get_notion_block_data_invalid_response(self, mock_request, extractor: NotionExtractor):
"""Test handling of invalid API response structure."""
# Arrange
mock_request.return_value = self._create_mock_response({"invalid": "structure"})
@ -322,7 +322,7 @@ class TestNotionExtractorPageRetrieval:
assert "Error fetching Notion block data" in str(exc_info.value)
@patch("httpx.request")
def test_get_notion_block_data_http_error(self, mock_request, extractor):
def test_get_notion_block_data_http_error(self, mock_request, extractor: NotionExtractor):
"""Test handling of HTTP errors during request."""
# Arrange
mock_request.side_effect = httpx.HTTPError("Network error")
@ -368,7 +368,7 @@ class TestNotionExtractorDatabaseRetrieval:
}
@patch("httpx.post")
def test_get_notion_database_data_simple(self, mock_post, extractor):
def test_get_notion_database_data_simple(self, mock_post, extractor: NotionExtractor):
"""Test retrieving simple database with basic properties."""
# Arrange
mock_response = Mock()
@ -407,7 +407,7 @@ class TestNotionExtractorDatabaseRetrieval:
assert "Status:Done" in content
@patch("httpx.post")
def test_get_notion_database_data_with_pagination(self, mock_post, extractor):
def test_get_notion_database_data_with_pagination(self, mock_post, extractor: NotionExtractor):
"""Test retrieving database with paginated results."""
# Arrange
first_response = Mock()
@ -441,7 +441,7 @@ class TestNotionExtractorDatabaseRetrieval:
assert mock_post.call_count == 2
@patch("httpx.post")
def test_get_notion_database_data_multi_select(self, mock_post, extractor):
def test_get_notion_database_data_multi_select(self, mock_post, extractor: NotionExtractor):
"""Test database with multi_select property type."""
# Arrange
mock_response = Mock()
@ -474,7 +474,7 @@ class TestNotionExtractorDatabaseRetrieval:
assert "Tags:" in content
@patch("httpx.post")
def test_get_notion_database_data_empty_properties(self, mock_post, extractor):
def test_get_notion_database_data_empty_properties(self, mock_post, extractor: NotionExtractor):
"""Test database with empty property values."""
# Arrange
mock_response = Mock()
@ -504,7 +504,7 @@ class TestNotionExtractorDatabaseRetrieval:
assert "Row Page URL:" in content
@patch("httpx.post")
def test_get_notion_database_data_empty_results(self, mock_post, extractor):
def test_get_notion_database_data_empty_results(self, mock_post, extractor: NotionExtractor):
"""Test handling of empty database."""
# Arrange
mock_response = Mock()
@ -523,7 +523,7 @@ class TestNotionExtractorDatabaseRetrieval:
assert len(result) == 0
@patch("httpx.post")
def test_get_notion_database_data_missing_results(self, mock_post, extractor):
def test_get_notion_database_data_missing_results(self, mock_post, extractor: NotionExtractor):
"""Test handling of malformed API response."""
# Arrange
mock_response = Mock()
@ -559,7 +559,7 @@ class TestNotionExtractorTableParsing:
)
@patch("httpx.request")
def test_read_table_rows_simple(self, mock_request, extractor):
def test_read_table_rows_simple(self, mock_request, extractor: NotionExtractor):
"""Test reading simple table with headers and rows."""
# Arrange
mock_data = {
@ -611,7 +611,7 @@ class TestNotionExtractorTableParsing:
assert "| Bob | 25 |" in result
@patch("httpx.request")
def test_read_table_rows_with_empty_cells(self, mock_request, extractor):
def test_read_table_rows_with_empty_cells(self, mock_request, extractor: NotionExtractor):
"""Test reading table with empty cells."""
# Arrange
mock_data = {
@ -643,7 +643,7 @@ class TestNotionExtractorTableParsing:
assert "Value1" in result
@patch("httpx.request")
def test_read_table_rows_with_pagination(self, mock_request, extractor):
def test_read_table_rows_with_pagination(self, mock_request, extractor: NotionExtractor):
"""Test reading table with paginated results."""
# Arrange
first_page = {
@ -960,7 +960,7 @@ class TestNotionExtractorReadBlock:
)
@patch("httpx.request")
def test_read_block_with_indentation(self, mock_request, extractor):
def test_read_block_with_indentation(self, mock_request, extractor: NotionExtractor):
"""Test reading nested blocks with proper indentation."""
# Arrange
mock_data = {
@ -990,7 +990,7 @@ class TestNotionExtractorReadBlock:
assert "\t\tNested content" in result
@patch("httpx.request")
def test_read_block_skip_child_page(self, mock_request, extractor):
def test_read_block_skip_child_page(self, mock_request, extractor: NotionExtractor):
"""Test that child_page blocks don't recurse."""
# Arrange
mock_data = {
@ -1139,7 +1139,7 @@ class TestNotionExtractorAdvancedBlockTypes:
}
@patch("httpx.request")
def test_get_notion_block_data_with_list_blocks(self, mock_request, extractor):
def test_get_notion_block_data_with_list_blocks(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with bulleted and numbered list items.
Both list types should be extracted with their content.
@ -1165,7 +1165,7 @@ class TestNotionExtractorAdvancedBlockTypes:
assert "Numbered item" in result[1]
@patch("httpx.request")
def test_get_notion_block_data_with_special_blocks(self, mock_request, extractor):
def test_get_notion_block_data_with_special_blocks(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with code, quote, and callout blocks.
Special block types should preserve their content correctly.
@ -1193,7 +1193,7 @@ class TestNotionExtractorAdvancedBlockTypes:
assert "Important note" in result[2]
@patch("httpx.request")
def test_get_notion_block_data_with_toggle_block(self, mock_request, extractor):
def test_get_notion_block_data_with_toggle_block(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with toggle block containing children.
Toggle blocks can have nested content that should be extracted.
@ -1229,7 +1229,7 @@ class TestNotionExtractorAdvancedBlockTypes:
assert "Hidden content" in result[0]
@patch("httpx.request")
def test_get_notion_block_data_mixed_block_types(self, mock_request, extractor):
def test_get_notion_block_data_mixed_block_types(self, mock_request, extractor: NotionExtractor):
"""Test retrieving page with mixed block types.
Real Notion pages contain various block types mixed together.
@ -1308,7 +1308,7 @@ class TestNotionExtractorDatabaseAdvanced:
}
@patch("httpx.post")
def test_get_notion_database_data_with_various_property_types(self, mock_post, extractor):
def test_get_notion_database_data_with_various_property_types(self, mock_post, extractor: NotionExtractor):
"""Test database with multiple property types.
Tests date, number, checkbox, URL, email, phone, and status properties.
@ -1354,7 +1354,7 @@ class TestNotionExtractorDatabaseAdvanced:
assert "Status:Active" in content
@patch("httpx.post")
def test_get_notion_database_data_large_pagination(self, mock_post, extractor):
def test_get_notion_database_data_large_pagination(self, mock_post, extractor: NotionExtractor):
"""Test database with multiple pages of results.
Large databases require multiple API calls with cursor-based pagination.
@ -1415,7 +1415,7 @@ class TestNotionExtractorDatabaseAdvanced:
assert mock_post.call_count == 3
@patch("httpx.post")
def test_get_notion_database_data_with_rich_text_property(self, mock_post, extractor):
def test_get_notion_database_data_with_rich_text_property(self, mock_post, extractor: NotionExtractor):
"""Test database with rich_text property type.
Rich text properties can contain formatted text and should be extracted.
@ -1486,7 +1486,9 @@ class TestNotionExtractorErrorScenarios:
],
)
@patch("httpx.request")
def test_get_notion_block_data_network_errors(self, mock_request, extractor, error_type, error_value):
def test_get_notion_block_data_network_errors(
self, mock_request, extractor: NotionExtractor, error_type, error_value
):
"""Test handling of various network errors.
Network issues (timeouts, connection failures) should raise appropriate errors.
@ -1509,7 +1511,9 @@ class TestNotionExtractorErrorScenarios:
],
)
@patch("httpx.request")
def test_get_notion_block_data_http_status_errors(self, mock_request, extractor, status_code, description):
def test_get_notion_block_data_http_status_errors(
self, mock_request, extractor: NotionExtractor, status_code, description
):
"""Test handling of various HTTP status errors.
Different HTTP error codes (401, 403, 404, 429) should be handled appropriately.
@ -1534,7 +1538,9 @@ class TestNotionExtractorErrorScenarios:
],
)
@patch("httpx.request")
def test_get_notion_block_data_malformed_responses(self, mock_request, extractor, response_data, description):
def test_get_notion_block_data_malformed_responses(
self, mock_request, extractor: NotionExtractor, response_data, description
):
"""Test handling of malformed API responses.
Various malformed responses should be handled gracefully.
@ -1551,7 +1557,7 @@ class TestNotionExtractorErrorScenarios:
assert "Error fetching Notion block data" in str(exc_info.value)
@patch("httpx.post")
def test_get_notion_database_data_with_query_filter(self, mock_post, extractor):
def test_get_notion_database_data_with_query_filter(self, mock_post, extractor: NotionExtractor):
"""Test database query with custom filter.
Databases can be queried with filters to retrieve specific rows.
@ -1618,7 +1624,7 @@ class TestNotionExtractorTableAdvanced:
)
@patch("httpx.request")
def test_read_table_rows_with_many_columns(self, mock_request, extractor):
def test_read_table_rows_with_many_columns(self, mock_request, extractor: NotionExtractor):
"""Test reading table with many columns.
Tables can have numerous columns; all should be extracted correctly.

View File

@ -19,22 +19,22 @@ class TestOutputModeration:
return ModerationRule(type="keywords", config={"keywords": "badword"})
@pytest.fixture
def output_moderation(self, mock_queue_manager, moderation_rule):
def output_moderation(self, mock_queue_manager, moderation_rule: ModerationRule):
return OutputModeration(
tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager
)
def test_should_direct_output(self, output_moderation):
def test_should_direct_output(self, output_moderation: OutputModeration):
assert output_moderation.should_direct_output() is False
output_moderation.final_output = "blocked"
assert output_moderation.should_direct_output() is True
def test_get_final_output(self, output_moderation):
def test_get_final_output(self, output_moderation: OutputModeration):
assert output_moderation.get_final_output() == ""
output_moderation.final_output = "blocked"
assert output_moderation.get_final_output() == "blocked"
def test_append_new_token(self, output_moderation):
def test_append_new_token(self, output_moderation: OutputModeration):
with patch.object(OutputModeration, "start_thread") as mock_start:
output_moderation.append_new_token("hello")
assert output_moderation.buffer == "hello"
@ -45,7 +45,7 @@ class TestOutputModeration:
assert output_moderation.buffer == "hello world"
assert mock_start.call_count == 1
def test_moderation_completion_no_flag(self, output_moderation):
def test_moderation_completion_no_flag(self, output_moderation: OutputModeration):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
@ -55,7 +55,7 @@ class TestOutputModeration:
assert flagged is False
assert output_moderation.is_final_chunk is True
def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager):
def test_moderation_completion_flagged_direct_output(self, output_moderation: OutputModeration, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset"
@ -71,7 +71,7 @@ class TestOutputModeration:
assert args[0].text == "preset"
assert args[1] == PublishFrom.TASK_PIPELINE
def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager):
def test_moderation_completion_flagged_overridden(self, output_moderation: OutputModeration, mock_queue_manager):
with patch.object(OutputModeration, "moderation") as mock_moderation:
mock_moderation.return_value = ModerationOutputsResult(
flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content"
@ -85,7 +85,7 @@ class TestOutputModeration:
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked content"
def test_start_thread(self, output_moderation):
def test_start_thread(self, output_moderation: OutputModeration):
mock_app = MagicMock(spec=Flask)
with patch("core.moderation.output_moderation.current_app") as mock_current_app:
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
@ -99,7 +99,7 @@ class TestOutputModeration:
mock_thread_class.assert_called_once()
mock_thread_instance.start.assert_called_once()
def test_stop_thread(self, output_moderation):
def test_stop_thread(self, output_moderation: OutputModeration):
mock_thread = MagicMock()
mock_thread.is_alive.return_value = True
output_moderation.thread = mock_thread
@ -113,7 +113,7 @@ class TestOutputModeration:
assert output_moderation.thread_running is True
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_success(self, mock_factory_class, output_moderation):
def test_moderation_success(self, mock_factory_class, output_moderation: OutputModeration):
mock_factory = mock_factory_class.return_value
mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT)
mock_factory.moderation_for_outputs.return_value = mock_result
@ -126,13 +126,13 @@ class TestOutputModeration:
)
@patch("core.moderation.output_moderation.ModerationFactory")
def test_moderation_exception(self, mock_factory_class, output_moderation):
def test_moderation_exception(self, mock_factory_class, output_moderation: OutputModeration):
mock_factory_class.side_effect = Exception("error")
result = output_moderation.moderation("tenant", "app", "buffer")
assert result is None
def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager):
def test_worker_loop_and_exit(self, output_moderation: OutputModeration, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
# Test exit on thread_running=False
@ -140,7 +140,7 @@ class TestOutputModeration:
output_moderation.worker(mock_app, 10)
# Should exit immediately
def test_worker_no_flag(self, output_moderation):
def test_worker_no_flag(self, output_moderation: OutputModeration):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
@ -160,7 +160,7 @@ class TestOutputModeration:
assert mock_moderation.called
def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager):
def test_worker_flagged_direct_output(self, output_moderation: OutputModeration, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
@ -177,7 +177,7 @@ class TestOutputModeration:
mock_queue_manager.publish.assert_called_once()
# It breaks on DIRECT_OUTPUT
def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager):
def test_worker_flagged_overridden(self, output_moderation: OutputModeration, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
@ -199,7 +199,7 @@ class TestOutputModeration:
args, _ = mock_queue_manager.publish.call_args
assert args[0].text == "masked"
def test_worker_chunk_too_small(self, output_moderation):
def test_worker_chunk_too_small(self, output_moderation: OutputModeration):
mock_app = MagicMock(spec=Flask)
with patch("time.sleep") as mock_sleep:
# chunk_length < buffer_size and not is_final_chunk
@ -215,7 +215,7 @@ class TestOutputModeration:
mock_sleep.assert_called_once_with(1)
def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager):
def test_worker_empty_not_flagged(self, output_moderation: OutputModeration, mock_queue_manager):
mock_app = MagicMock(spec=Flask)
with patch.object(OutputModeration, "moderation") as mock_moderation:
# Return None (exception or no rule)

View File

@ -262,7 +262,7 @@ def workflow_repo_fixture(monkeypatch: pytest.MonkeyPatch):
@pytest.fixture
def trace_task_message(monkeypatch, mock_db):
def trace_task_message(monkeypatch: pytest.MonkeyPatch, mock_db):
message_data = make_message_data()
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data)
configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id"))
@ -353,7 +353,7 @@ def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch: pytest.Mo
assert OpsTraceManager.get_ops_trace_instance("app-id") is None
def test_get_ops_trace_instance_success(monkeypatch, mock_db):
def test_get_ops_trace_instance_success(monkeypatch: pytest.MonkeyPatch, mock_db):
app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"}))
mock_db.get.return_value = app
monkeypatch.setattr(
@ -497,7 +497,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message):
assert result.documents == [{"doc": "value"}]
def test_trace_task_tool_trace(monkeypatch, mock_db):
def test_trace_task_tool_trace(monkeypatch: pytest.MonkeyPatch, mock_db):
custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))])
monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message)
configure_db_scalar(mock_db, message_file=FakeMessageFile())

View File

@ -87,6 +87,7 @@ from uuid import uuid4
import pytest
from flask import Flask
from flask.testing import FlaskClient
from flask_restx import Api
from controllers.console.datasets.datasets import DatasetApi, DatasetListApi
@ -339,7 +340,7 @@ class TestDatasetListApi:
mock_get_user.return_value = (mock_user, mock_tenant_id)
yield mock_get_user
def test_get_datasets_success(self, client, mock_current_user):
def test_get_datasets_success(self, client: FlaskClient, mock_current_user):
"""
Test successful retrieval of dataset list.
@ -380,7 +381,7 @@ class TestDatasetListApi:
# Verify service was called
mock_get_datasets.assert_called_once()
def test_get_datasets_with_search(self, client, mock_current_user):
def test_get_datasets_with_search(self, client: FlaskClient, mock_current_user):
"""
Test dataset listing with search keyword.
@ -410,7 +411,7 @@ class TestDatasetListApi:
call_args = mock_get_datasets.call_args
assert call_args[1]["search"] == search_keyword
def test_get_datasets_with_pagination(self, client, mock_current_user):
def test_get_datasets_with_pagination(self, client: FlaskClient, mock_current_user):
"""
Test dataset listing with pagination parameters.
@ -495,7 +496,7 @@ class TestDatasetApiGet:
mock_get_user.return_value = (mock_user, mock_tenant_id)
yield mock_get_user
def test_get_dataset_success(self, client, mock_current_user):
def test_get_dataset_success(self, client: FlaskClient, mock_current_user):
"""
Test successful retrieval of a single dataset.
@ -533,7 +534,7 @@ class TestDatasetApiGet:
mock_get_dataset.assert_called_once_with(dataset_id)
mock_check_perm.assert_called_once()
def test_get_dataset_not_found(self, client, mock_current_user):
def test_get_dataset_not_found(self, client: FlaskClient, mock_current_user):
"""
Test error handling when dataset is not found.
@ -611,7 +612,7 @@ class TestDatasetApiCreate:
mock_get_user.return_value = (mock_user, mock_tenant_id)
yield mock_get_user
def test_create_dataset_success(self, client, mock_current_user):
def test_create_dataset_success(self, client: FlaskClient, mock_current_user):
"""
Test successful creation of a dataset.
@ -706,7 +707,7 @@ class TestHitTestingApi:
mock_get_user.return_value = (mock_user, mock_tenant_id)
yield mock_get_user
def test_hit_testing_success(self, client, mock_current_user):
def test_hit_testing_success(self, client: FlaskClient, mock_current_user):
"""
Test successful hit testing operation.

View File

@ -52,7 +52,7 @@ class TestFileService:
@patch("services.file_service.extract_tenant_id")
@patch("services.file_service.file_helpers.get_signed_file_url")
def test_upload_file_success(
self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session
self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service: FileService, mock_db_session
):
# Setup
mock_tenant_id.return_value = "tenant_id"
@ -88,7 +88,7 @@ class TestFileService:
with pytest.raises(ValueError, match="Filename contains invalid characters"):
file_service.upload_file(filename="invalid/file.txt", content=b"", mimetype="text/plain", user=MagicMock())
def test_upload_file_long_filename(self, file_service, mock_db_session):
def test_upload_file_long_filename(self, file_service: FileService, mock_db_session):
# Setup
long_name = "a" * 210 + ".txt"
user = MagicMock(spec=Account)
@ -124,7 +124,7 @@ class TestFileService:
with pytest.raises(FileTooLargeError):
file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock())
def test_upload_file_end_user(self, file_service, mock_db_session):
def test_upload_file_end_user(self, file_service: FileService, mock_db_session):
user = MagicMock(spec=EndUser)
user.id = "end_user_id"
@ -160,7 +160,7 @@ class TestFileService:
assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True
assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False
def test_get_file_base64_success(self, file_service, mock_db_session):
def test_get_file_base64_success(self, file_service: FileService, mock_db_session):
# Setup
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
@ -177,12 +177,12 @@ class TestFileService:
assert result == base64.b64encode(b"test content").decode()
mock_storage.load_once.assert_called_once_with("test_key")
def test_get_file_base64_not_found(self, file_service, mock_db_session):
def test_get_file_base64_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_base64("non_existent")
def test_upload_text_success(self, file_service, mock_db_session):
def test_upload_text_success(self, file_service: FileService, mock_db_session):
# Setup
text = "sample text"
text_name = "test.txt"
@ -204,13 +204,13 @@ class TestFileService:
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_upload_text_long_name(self, file_service, mock_db_session):
def test_upload_text_long_name(self, file_service: FileService, mock_db_session):
long_name = "a" * 210
with patch("services.file_service.storage"):
result = file_service.upload_text("text", long_name, "user", "tenant")
assert len(result.name) == 200
def test_get_file_preview_success(self, file_service, mock_db_session):
def test_get_file_preview_success(self, file_service: FileService, mock_db_session):
# Setup
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
@ -226,12 +226,12 @@ class TestFileService:
# Assert
assert result == "Extracted text content"
def test_get_file_preview_not_found(self, file_service, mock_db_session):
def test_get_file_preview_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_preview("non_existent", "tenant_id")
def test_get_file_preview_unsupported_type(self, file_service, mock_db_session):
def test_get_file_preview_unsupported_type(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "exe"
@ -239,7 +239,7 @@ class TestFileService:
with pytest.raises(UnsupportedFileTypeError):
file_service.get_file_preview("file_id", "tenant_id")
def test_get_image_preview_success(self, file_service, mock_db_session):
def test_get_image_preview_success(self, file_service: FileService, mock_db_session):
# Setup
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
@ -268,14 +268,14 @@ class TestFileService:
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
def test_get_image_preview_not_found(self, file_service, mock_db_session):
def test_get_image_preview_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
def test_get_image_preview_unsupported_type(self, file_service, mock_db_session):
def test_get_image_preview_unsupported_type(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
@ -285,7 +285,7 @@ class TestFileService:
with pytest.raises(UnsupportedFileTypeError):
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session):
def test_get_file_generator_by_file_id_success(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
@ -308,14 +308,14 @@ class TestFileService:
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
def test_get_file_generator_by_file_id_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
def test_get_public_image_preview_success(self, file_service, mock_db_session):
def test_get_public_image_preview_success(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "png"
@ -329,12 +329,12 @@ class TestFileService:
assert gen == b"image content"
assert mime == "image/png"
def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
def test_get_public_image_preview_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_public_image_preview("file_id")
def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session):
def test_get_public_image_preview_unsupported_type(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
@ -342,7 +342,7 @@ class TestFileService:
with pytest.raises(UnsupportedFileTypeError):
file_service.get_public_image_preview("file_id")
def test_get_file_content_success(self, file_service, mock_db_session):
def test_get_file_content_success(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
@ -353,12 +353,12 @@ class TestFileService:
result = file_service.get_file_content("file_id")
assert result == "hello world"
def test_get_file_content_not_found(self, file_service, mock_db_session):
def test_get_file_content_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_content("file_id")
def test_delete_file_success(self, file_service, mock_db_session):
def test_delete_file_success(self, file_service: FileService, mock_db_session):
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
@ -370,7 +370,7 @@ class TestFileService:
mock_storage.delete.assert_called_once_with("key")
mock_db_session.delete.assert_called_once_with(upload_file)
def test_delete_file_not_found(self, file_service, mock_db_session):
def test_delete_file_not_found(self, file_service: FileService, mock_db_session):
mock_db_session.scalar.return_value = None
file_service.delete_file("file_id")
# Should return without doing anything

View File

@ -151,9 +151,9 @@ class TestErrorHandling:
def test_clean_dataset_task_rollback_failure_still_closes_session(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
mock_db_session,
mock_storage,
mock_index_processor_factory,
@ -198,9 +198,9 @@ class TestPipelineAndWorkflowDeletion:
def test_clean_dataset_task_with_pipeline_id(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
pipeline_id,
mock_db_session,
mock_storage,
@ -231,9 +231,9 @@ class TestPipelineAndWorkflowDeletion:
def test_clean_dataset_task_without_pipeline_id(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
mock_db_session,
mock_storage,
mock_index_processor_factory,
@ -271,9 +271,9 @@ class TestSegmentAttachmentCleanup:
def test_clean_dataset_task_with_attachments(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
mock_db_session,
mock_storage,
mock_index_processor_factory,
@ -321,9 +321,9 @@ class TestSegmentAttachmentCleanup:
def test_clean_dataset_task_attachment_storage_failure(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
mock_db_session,
mock_storage,
mock_index_processor_factory,
@ -375,9 +375,9 @@ class TestEdgeCases:
def test_clean_dataset_task_session_always_closed(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
mock_db_session,
mock_storage,
mock_index_processor_factory,
@ -413,9 +413,9 @@ class TestIndexProcessorParameters:
def test_clean_dataset_task_passes_correct_parameters_to_index_processor(
self,
dataset_id,
tenant_id,
collection_binding_id,
dataset_id: str,
tenant_id: str,
collection_binding_id: str,
mock_db_session,
mock_storage,
mock_index_processor_factory,

View File

@ -124,10 +124,10 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
notion_workspace_id,
notion_page_id,
dataset_id: str,
document_id: str,
notion_workspace_id: str,
notion_page_id: str,
):
"""Test that NotionExtractor is initialized with expected arguments."""
# Arrange
@ -151,9 +151,9 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
credential_id,
dataset_id: str,
document_id: str,
credential_id: str,
):
"""Test that datasource credentials are requested with expected identifiers."""
# Arrange
@ -176,8 +176,8 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
dataset_id: str,
document_id: str,
):
"""Test that missing credential_id is forwarded as None."""
# Arrange
@ -212,8 +212,8 @@ class TestDataSourceInfoSerialization:
self,
mock_document,
mock_dataset,
dataset_id,
document_id,
dataset_id: str,
document_id: str,
):
"""data_source_info must be serialized with json.dumps before DB write."""
with (