test: add type to test (#35871)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-05-08 10:06:25 +09:00 committed by GitHub
parent 203b3a9499
commit ecd830083a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 706 additions and 656 deletions

View File

@ -90,7 +90,7 @@ class TestOAuthLogin:
mock_redirect,
mock_get_providers,
resource,
app,
app: Flask,
mock_oauth_provider,
invite_token,
expected_token,
@ -165,7 +165,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -218,7 +218,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -262,7 +262,7 @@ class TestOAuthCallback:
mock_tenant_service,
mock_account_service,
resource,
app,
app: Flask,
oauth_setup,
account_status,
expected_redirect,
@ -301,7 +301,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -337,7 +337,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
"""Defensive test for CLOSED account status handling in OAuth callback.
@ -466,7 +466,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
allow_register,
@ -505,7 +505,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
@ -530,7 +530,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_tenant_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
):

View File

@ -47,7 +47,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
):
# Arrange
@ -105,7 +105,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -154,7 +154,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
"""
Test successful verification code validation.
@ -201,7 +201,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
@ -345,7 +345,7 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
app,
app: Flask,
mock_account,
):
"""

View File

@ -30,7 +30,7 @@ class TestPipelineTemplateListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateListApi()
method = unwrap(api.get)
@ -54,7 +54,7 @@ class TestPipelineTemplateDetailApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -75,7 +75,7 @@ class TestPipelineTemplateDetailApi:
assert status == 200
assert response == template
def test_get_returns_404_when_template_not_found(self, app):
def test_get_returns_404_when_template_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -94,7 +94,7 @@ class TestPipelineTemplateDetailApi:
assert status == 404
assert "error" in response
def test_get_returns_404_for_customized_type_not_found(self, app):
def test_get_returns_404_for_customized_type_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -119,7 +119,7 @@ class TestCustomizedPipelineTemplateApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
@ -141,7 +141,7 @@ class TestCustomizedPipelineTemplateApi:
update_mock.assert_called_once()
assert response == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
@ -156,7 +156,7 @@ class TestCustomizedPipelineTemplateApi:
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app, db_session_with_containers: Session):
def test_post_success(self, app: Flask, db_session_with_containers: Session):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -183,7 +183,7 @@ class TestCustomizedPipelineTemplateApi:
assert status == 200
assert response == {"data": "yaml-data"}
def test_post_template_not_found(self, app):
def test_post_template_not_found(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -197,7 +197,7 @@ class TestPublishCustomizedPipelineTemplateApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)

View File

@ -36,7 +36,7 @@ class TestRagPipelineImportApi:
"name": "Test",
}
def test_post_success_200(self, app):
def test_post_success_200(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -66,7 +66,7 @@ class TestRagPipelineImportApi:
assert status == 200
assert response == {"status": "success"}
def test_post_failed_400(self, app):
def test_post_failed_400(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -96,7 +96,7 @@ class TestRagPipelineImportApi:
assert status == 400
assert response == {"status": "failed"}
def test_post_pending_202(self, app):
def test_post_pending_202(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -132,7 +132,7 @@ class TestRagPipelineImportConfirmApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_confirm_success(self, app):
def test_confirm_success(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -160,7 +160,7 @@ class TestRagPipelineImportConfirmApi:
assert status == 200
assert response == {"ok": True}
def test_confirm_failed(self, app):
def test_confirm_failed(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -194,7 +194,7 @@ class TestRagPipelineImportCheckDependenciesApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
@ -223,7 +223,7 @@ class TestRagPipelineExportApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_with_include_secret(self, app):
def test_get_with_include_secret(self, app: Flask):
api = RagPipelineExportApi()
method = unwrap(api.get)

View File

@ -391,7 +391,7 @@ class TestPublishedPipelineApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi()

View File

@ -55,7 +55,7 @@ class TestDataSourceApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -79,7 +79,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
def test_get_no_bindings(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -95,7 +95,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -116,7 +116,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -137,7 +137,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -152,7 +152,7 @@ class TestDataSourceApi:
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -169,7 +169,7 @@ class TestDataSourceApi:
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -192,7 +192,7 @@ class TestDataSourceNotionListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_credential_not_found(self, app, patch_tenant):
def test_get_credential_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -206,7 +206,7 @@ class TestDataSourceNotionListApi:
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -247,7 +247,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -300,7 +300,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -327,7 +327,7 @@ class TestDataSourceNotionApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_preview_success(self, app, patch_tenant):
def test_get_preview_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
@ -348,7 +348,7 @@ class TestDataSourceNotionApi:
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
@ -385,7 +385,7 @@ class TestDataSourceNotionDatasetSyncApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -408,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi:
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -428,7 +428,7 @@ class TestDataSourceNotionDocumentSyncApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
@ -451,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi:
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
def test_get_document_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)

View File

@ -57,7 +57,7 @@ class TestConversationListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
def test_get_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -82,7 +82,7 @@ class TestConversationListApi:
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app, chat_app, user):
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -98,7 +98,7 @@ class TestConversationListApi:
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app, non_chat_app):
def test_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -112,7 +112,7 @@ class TestConversationApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
def test_delete_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -130,7 +130,7 @@ class TestConversationApi:
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app, chat_app, user):
def test_delete_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -146,7 +146,7 @@ class TestConversationApi:
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app, non_chat_app):
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -160,7 +160,7 @@ class TestConversationRenameApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
def test_rename_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -179,7 +179,7 @@ class TestConversationRenameApi:
assert result["id"] == "cid"
def test_rename_not_found(self, app, chat_app, user):
def test_rename_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -201,7 +201,7 @@ class TestConversationPinApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
def test_pin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
@ -223,7 +223,7 @@ class TestConversationUnPinApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):
def test_unpin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)

View File

@ -49,7 +49,7 @@ class TestTriggerProviderApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = TriggerProviderIconApi()
method = unwrap(api.get)
@ -63,7 +63,7 @@ class TestTriggerProviderApis:
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
def test_list_providers(self, app: Flask):
api = TriggerProviderListApi()
method = unwrap(api.get)
@ -77,7 +77,7 @@ class TestTriggerProviderApis:
):
assert method(api) == []
def test_provider_info(self, app):
def test_provider_info(self, app: Flask):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
@ -97,7 +97,7 @@ class TestTriggerSubscriptionListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -111,7 +111,7 @@ class TestTriggerSubscriptionListApi:
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
def test_list_invalid_provider(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -132,7 +132,7 @@ class TestTriggerSubscriptionBuilderApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create_builder(self, app):
def test_create_builder(self, app: Flask):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
@ -147,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis:
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
def test_get_builder(self, app: Flask):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
@ -160,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
def test_verify_builder(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -174,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
def test_verify_builder_error(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -189,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis:
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
def test_update_builder(self, app: Flask):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
@ -203,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
def test_logs(self, app: Flask):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
@ -220,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
def test_build(self, app: Flask):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
@ -240,7 +240,7 @@ class TestTriggerSubscriptionCrud:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_update_rename_only(self, app):
def test_update_rename_only(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -259,7 +259,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
def test_update_not_found(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -274,7 +274,7 @@ class TestTriggerSubscriptionCrud:
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
def test_update_rebuild(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -297,7 +297,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
def test_delete_subscription(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -320,7 +320,7 @@ class TestTriggerSubscriptionCrud:
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
def test_delete_subscription_value_error(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -346,7 +346,7 @@ class TestTriggerOAuthApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_authorize_success(self, app):
def test_oauth_authorize_success(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -373,7 +373,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
def test_oauth_authorize_no_client(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -388,7 +388,7 @@ class TestTriggerOAuthApis:
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
def test_oauth_callback_forbidden(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -396,7 +396,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
def test_oauth_callback_success(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -426,7 +426,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
def test_oauth_callback_no_oauth_client(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -450,7 +450,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
def test_oauth_callback_empty_credentials(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -484,7 +484,7 @@ class TestTriggerOAuthClientManageApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_client(self, app):
def test_get_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
@ -511,7 +511,7 @@ class TestTriggerOAuthClientManageApi:
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
def test_post_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -525,7 +525,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
def test_delete_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
@ -539,7 +539,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
def test_oauth_client_post_value_error(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -560,7 +560,7 @@ class TestTriggerSubscriptionVerifyApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_verify_success(self, app):
def test_verify_success(self, app: Flask):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)

View File

@ -291,7 +291,7 @@ class TestDatasetListApiGet:
mock_current_user,
mock_provider_mgr,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -326,7 +326,7 @@ class TestDatasetListApiPost:
mock_dataset_svc,
mock_current_user,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -352,7 +352,7 @@ class TestDatasetListApiPost:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -390,7 +390,7 @@ class TestDatasetApiGet:
mock_provider_mgr,
mock_marshal,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -440,7 +440,7 @@ class TestDatasetApiGet:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -468,7 +468,7 @@ class TestDatasetApiDelete:
mock_dataset_svc,
mock_current_user,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -490,7 +490,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -511,7 +511,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -543,7 +543,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -574,7 +574,7 @@ class TestDocumentStatusApiPatch:
def test_batch_update_status_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -603,7 +603,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -636,7 +636,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -669,7 +669,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -709,7 +709,7 @@ class TestDatasetTagsApiGet:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -731,7 +731,7 @@ class TestDatasetTagsApiGet:
def test_list_tags_from_db(
self,
mock_current_user,
app,
app: Flask,
db_session_with_containers: Session,
):
"""Integration test: creates real Tag rows and retrieves them
@ -774,7 +774,7 @@ class TestDatasetTagsApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -797,7 +797,7 @@ class TestDatasetTagsApiPost:
mock_tag_svc.save_tags.assert_called_once()
@patch("controllers.service_api.dataset.dataset.current_user")
def test_create_tag_forbidden(self, mock_current_user, app):
def test_create_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -826,7 +826,7 @@ class TestDatasetTagsApiPatch:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -852,7 +852,7 @@ class TestDatasetTagsApiPatch:
mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_update_tag_forbidden(self, mock_current_user, app):
def test_update_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -880,7 +880,7 @@ class TestDatasetTagsApiDelete:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -905,7 +905,7 @@ class TestDatasetTagsApiDelete:
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
@patch("libs.login.current_user")
def test_delete_tag_forbidden(self, mock_current_user, app):
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
user_obj = Mock(spec=Account)
@ -933,7 +933,7 @@ class TestDatasetTagsBindingStatusApi:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi
@ -963,7 +963,7 @@ class TestDatasetTagBindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
@ -988,7 +988,7 @@ class TestDatasetTagBindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_bind_tags_forbidden(self, mock_current_user, app):
def test_bind_tags_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
mock_current_user.__class__ = Account
@ -1014,7 +1014,7 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1044,7 +1044,7 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1069,7 +1069,7 @@ class TestDatasetTagUnbindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_unbind_tag_forbidden(self, mock_current_user, app):
def test_unbind_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
mock_current_user.__class__ = Account

View File

@ -240,7 +240,7 @@ class TestDecodeJwtToken:
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
@ -300,7 +300,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
@ -325,7 +325,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
@ -351,7 +351,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)

View File

@ -21,6 +21,8 @@ from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
TenantAndAccount = tuple[Tenant, Account]
@dataclass
class TestTask:
@ -74,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration:
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
def test_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
def secondary_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
@ -95,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker):
def test_tenant_isolation(
self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker
):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
@ -115,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
def test_key_isolation(self, test_tenant_and_account: TenantAndAccount):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
@ -293,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker):
def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
@ -436,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker):
def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test compatibility with legacy queues containing only string data.
@ -466,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker):
def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test complete migration scenario from legacy to new system.
@ -547,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker):
def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test error recovery when legacy queue contains malformed data.

View File

@ -7,6 +7,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import App
from models.model import EndUser
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@ -184,7 +185,7 @@ class TestAppGenerateService:
return app, account
def _create_test_workflow(self, db_session_with_containers: Session, app):
def _create_test_workflow(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test workflow for testing.

View File

@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration:
return app
def _create_conversation(self, db_session_with_containers: Session, app):
def _create_conversation(self, db_session_with_containers: Session, app: App):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import App, CreatorUserRole
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
@ -88,7 +89,7 @@ class TestSavedMessageService:
return app, account
def _create_test_end_user(self, db_session_with_containers: Session, app):
def _create_test_end_user(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test end user for testing.
@ -116,7 +117,7 @@ class TestSavedMessageService:
return end_user
def _create_test_message(self, db_session_with_containers: Session, app, user):
def _create_test_message(self, db_session_with_containers: Session, app: App, user):
"""
Helper method to create a test message for testing.
@ -199,13 +200,13 @@ class TestSavedMessageService:
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -272,13 +273,13 @@ class TestSavedMessageService:
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="end_user",
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="end_user",
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
@ -449,7 +450,7 @@ class TestSavedMessageService:
saved_message = SavedMessage(
app_id=app.id,
message_id=message.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -540,7 +541,9 @@ class TestSavedMessageService:
message = self._create_test_message(db_session_with_containers, app, account)
# Pre-create a saved message
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id)
saved = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id
)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
@ -571,7 +574,9 @@ class TestSavedMessageService:
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id)
saved = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id
)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
@ -596,10 +601,10 @@ class TestSavedMessageService:
# Both users save the same message
saved_account = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account1.id
)
saved_end_user = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id
)
db_session_with_containers.add_all([saved_account, saved_end_user])
db_session_with_containers.commit()

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models import Account, App
from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
@ -93,7 +93,7 @@ class TestWebConversationService:
return app, account
def _create_test_end_user(self, db_session_with_containers: Session, app):
def _create_test_end_user(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test end user for testing.

View File

@ -185,7 +185,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):
@ -263,7 +263,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):
@ -312,7 +312,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
language,
@ -358,7 +358,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
):
"""
@ -398,7 +398,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
):
"""
@ -438,7 +438,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):

View File

@ -195,7 +195,7 @@ class TestEmailCodeLoginSendEmailApi:
mock_get_user,
mock_is_ip_limit,
mock_db,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -267,7 +267,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -315,7 +315,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -431,7 +431,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""
@ -474,7 +474,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""
@ -515,7 +515,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""

View File

@ -412,7 +412,7 @@ class TestLoginApi:
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -448,7 +448,7 @@ class TestLoginApi:
mock_revoke_token,
mock_get_token_data,
mock_db,
app,
app: Flask,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")

View File

@ -74,7 +74,7 @@ class TestRefreshTokenApi:
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
def test_refresh_fails_without_token(self, mock_extract_token, app):
def test_refresh_fails_without_token(self, mock_extract_token, app: Flask):
"""
Test token refresh failure when no refresh token provided.
@ -98,7 +98,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh failure with invalid refresh token.
@ -123,7 +123,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh failure with expired refresh token.
@ -148,7 +148,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh with empty string token.

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import console_ns
@ -29,7 +30,7 @@ def unwrap(func):
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -61,7 +62,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
assert response.status_code == 200
def test_get_no_oauth_config(self, app):
def test_get_no_oauth_config(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -80,7 +81,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_without_credential_id_sets_cookie(self, app):
def test_get_without_credential_id_sets_cookie(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -115,7 +116,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app):
def test_callback_success_new_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -157,7 +158,7 @@ class TestDatasourceOAuthCallback:
assert response.status_code == 302
def test_callback_missing_context(self, app):
def test_callback_missing_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -165,7 +166,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_invalid_context(self, app):
def test_callback_invalid_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -180,7 +181,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_oauth_config_not_found(self, app):
def test_callback_oauth_config_not_found(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -202,7 +203,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(NotFound):
method(api, "notion")
def test_callback_reauthorize_existing_credential(self, app):
def test_callback_reauthorize_existing_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -245,7 +246,7 @@ class TestDatasourceOAuthCallback:
assert response.status_code == 302
assert "/oauth-callback" in response.location
def test_callback_context_id_from_cookie(self, app):
def test_callback_context_id_from_cookie(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -289,7 +290,7 @@ class TestDatasourceOAuthCallback:
class TestDatasourceAuth:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -312,7 +313,7 @@ class TestDatasourceAuth:
assert status == 200
def test_post_invalid_credentials(self, app):
def test_post_invalid_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -334,7 +335,7 @@ class TestDatasourceAuth:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
@ -355,7 +356,7 @@ class TestDatasourceAuth:
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app):
def test_post_missing_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -372,7 +373,7 @@ class TestDatasourceAuth:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_empty_list(self, app):
def test_get_empty_list(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
@ -395,7 +396,7 @@ class TestDatasourceAuth:
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
@ -418,7 +419,7 @@ class TestDatasourceAuthDeleteApi:
assert status == 200
def test_delete_missing_credential_id(self, app):
def test_delete_missing_credential_id(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
@ -437,7 +438,7 @@ class TestDatasourceAuthDeleteApi:
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -460,7 +461,7 @@ class TestDatasourceAuthUpdateApi:
assert status == 201
def test_update_with_credentials_none(self, app):
def test_update_with_credentials_none(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -484,7 +485,7 @@ class TestDatasourceAuthUpdateApi:
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app):
def test_update_name_only(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -507,7 +508,7 @@ class TestDatasourceAuthUpdateApi:
assert status == 201
def test_update_with_empty_credentials_dict(self, app):
def test_update_with_empty_credentials_dict(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -533,7 +534,7 @@ class TestDatasourceAuthUpdateApi:
class TestDatasourceAuthListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
@ -553,7 +554,7 @@ class TestDatasourceAuthListApi:
assert status == 200
def test_auth_list_empty(self, app):
def test_auth_list_empty(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
@ -574,7 +575,7 @@ class TestDatasourceAuthListApi:
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app):
def test_hardcode_list_empty(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
@ -597,7 +598,7 @@ class TestDatasourceAuthListApi:
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
@ -619,7 +620,7 @@ class TestDatasourceHardCodeAuthListApi:
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -642,7 +643,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
@ -662,7 +663,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_post_empty_payload(self, app):
def test_post_empty_payload(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -685,7 +686,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_post_disabled_flag(self, app):
def test_post_disabled_flag(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -714,7 +715,7 @@ class TestDatasourceAuthOauthCustomClient:
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app):
def test_set_default_success(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
@ -737,7 +738,7 @@ class TestDatasourceAuthDefaultApi:
assert status == 200
def test_default_missing_id(self, app):
def test_default_missing_id(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
@ -756,7 +757,7 @@ class TestDatasourceAuthDefaultApi:
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app):
def test_update_name_success(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
@ -779,7 +780,7 @@ class TestDatasourceUpdateProviderNameApi:
assert status == 200
def test_update_name_too_long(self, app):
def test_update_name_too_long(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
@ -799,7 +800,7 @@ class TestDatasourceUpdateProviderNameApi:
with pytest.raises(ValueError):
method(api, "notion")
def test_update_name_missing_credential_id(self, app):
def test_update_name_missing_credential_id(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
@ -25,7 +26,7 @@ class TestDataSourceContentPreviewApi:
"credential_id": "cred-1",
}
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -66,7 +67,7 @@ class TestDataSourceContentPreviewApi:
assert status == 200
assert response == preview_result
def test_post_forbidden_non_account_user(self, app):
def test_post_forbidden_non_account_user(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -85,7 +86,7 @@ class TestDataSourceContentPreviewApi:
with pytest.raises(Forbidden):
method(api, pipeline, "node-1")
def test_post_invalid_payload(self, app):
def test_post_invalid_payload(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -108,7 +109,7 @@ class TestDataSourceContentPreviewApi:
with pytest.raises(ValueError):
method(api, pipeline, "node-1")
def test_post_without_credential_id(self, app):
def test_post_without_credential_id(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)

View File

@ -2,6 +2,7 @@ import datetime
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import services
@ -58,7 +59,7 @@ class TestDatasetList:
user.is_dataset_editor = True
return user
def test_get_success_basic(self, app):
def test_get_success_basic(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -93,7 +94,7 @@ class TestDatasetList:
assert resp["total"] == 1
assert resp["data"][0]["embedding_available"] is True
def test_get_with_ids_filter(self, app):
def test_get_with_ids_filter(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -128,7 +129,7 @@ class TestDatasetList:
assert status == 200
assert resp["total"] == 2
def test_get_with_tag_ids(self, app):
def test_get_with_tag_ids(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -161,7 +162,7 @@ class TestDatasetList:
assert status == 200
def test_embedding_available_false(self, app):
def test_embedding_available_false(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -203,7 +204,7 @@ class TestDatasetList:
assert resp["data"][0]["embedding_available"] is False
def test_partial_members_permission(self, app):
def test_partial_members_permission(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -242,7 +243,7 @@ class TestDatasetList:
class TestDatasetListApiPost:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -290,7 +291,7 @@ class TestDatasetListApiPost:
assert status == 201
def test_post_forbidden(self, app):
def test_post_forbidden(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -310,7 +311,7 @@ class TestDatasetListApiPost:
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
def test_post_duplicate_name(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -335,7 +336,7 @@ class TestDatasetListApiPost:
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload_missing_name(self, app):
def test_post_invalid_payload_missing_name(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -343,7 +344,7 @@ class TestDatasetListApiPost:
with pytest.raises(ValueError):
method(api)
def test_post_invalid_indexing_technique(self, app):
def test_post_invalid_indexing_technique(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -356,7 +357,7 @@ class TestDatasetListApiPost:
with pytest.raises(ValueError, match="Invalid indexing technique"):
method(api)
def test_post_invalid_provider(self, app):
def test_post_invalid_provider(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -371,7 +372,7 @@ class TestDatasetListApiPost:
class TestDatasetApiGet:
def test_get_success_basic(self, app):
def test_get_success_basic(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -427,7 +428,7 @@ class TestDatasetApiGet:
assert status == 200
assert data["embedding_available"] is True
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -448,7 +449,7 @@ class TestDatasetApiGet:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -475,7 +476,7 @@ class TestDatasetApiGet:
with pytest.raises(Forbidden, match="no access"):
method(api, dataset_id)
def test_get_high_quality_embedding_unavailable(self, app):
def test_get_high_quality_embedding_unavailable(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -530,7 +531,7 @@ class TestDatasetApiGet:
assert data["embedding_available"] is False
def test_get_partial_members_permission(self, app):
def test_get_partial_members_permission(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -590,7 +591,7 @@ class TestDatasetApiGet:
class TestDatasetApiPatch:
def test_patch_success_basic(self, app):
def test_patch_success_basic(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -659,7 +660,7 @@ class TestDatasetApiPatch:
assert status == 200
assert result["partial_member_list"] == []
def test_patch_dataset_not_found(self, app):
def test_patch_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -674,7 +675,7 @@ class TestDatasetApiPatch:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, "missing")
def test_patch_permission_denied(self, app):
def test_patch_permission_denied(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -704,7 +705,7 @@ class TestDatasetApiPatch:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_patch_partial_members_update(self, app):
def test_patch_partial_members_update(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -773,7 +774,7 @@ class TestDatasetApiPatch:
assert result["partial_member_list"] == payload["partial_member_list"]
def test_patch_clear_partial_members(self, app):
def test_patch_clear_partial_members(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -843,7 +844,7 @@ class TestDatasetApiPatch:
class TestDatasetApiDelete:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -874,7 +875,7 @@ class TestDatasetApiDelete:
assert status == 204
assert result == {"result": "success"}
def test_delete_forbidden_no_permission(self, app):
def test_delete_forbidden_no_permission(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -893,7 +894,7 @@ class TestDatasetApiDelete:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_delete_dataset_not_found(self, app):
def test_delete_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -917,7 +918,7 @@ class TestDatasetApiDelete:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_delete_dataset_in_use(self, app):
def test_delete_dataset_in_use(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -943,7 +944,7 @@ class TestDatasetApiDelete:
class TestDatasetUseCheckApi:
def test_get_use_check_true(self, app):
def test_get_use_check_true(self, app: Flask):
api = DatasetUseCheckApi()
method = unwrap(api.get)
@ -962,7 +963,7 @@ class TestDatasetUseCheckApi:
assert status == 200
assert result == {"is_using": True}
def test_get_use_check_false(self, app):
def test_get_use_check_false(self, app: Flask):
api = DatasetUseCheckApi()
method = unwrap(api.get)
@ -983,7 +984,7 @@ class TestDatasetUseCheckApi:
class TestDatasetQueryApi:
def test_get_queries_success(self, app):
def test_get_queries_success(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1027,7 +1028,7 @@ class TestDatasetQueryApi:
assert response["has_more"] is False
assert len(response["data"]) == 2
def test_get_queries_dataset_not_found(self, app):
def test_get_queries_dataset_not_found(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1049,7 +1050,7 @@ class TestDatasetQueryApi:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_get_queries_permission_denied(self, app):
def test_get_queries_permission_denied(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1078,7 +1079,7 @@ class TestDatasetQueryApi:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_get_queries_pagination_has_more(self, app):
def test_get_queries_pagination_has_more(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1152,7 +1153,7 @@ class TestDatasetIndexingEstimateApi:
"dataset_id": None,
}
def test_post_success_upload_file(self, app):
def test_post_success_upload_file(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
@ -1193,7 +1194,7 @@ class TestDatasetIndexingEstimateApi:
assert status == 200
assert response == {"tokens": 100}
def test_post_file_not_found(self, app):
def test_post_file_not_found(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
@ -1223,7 +1224,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(NotFound):
method(api)
def test_post_llm_bad_request_error(self, app):
def test_post_llm_bad_request_error(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1258,7 +1259,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(ProviderNotInitializeError):
method(api)
def test_post_provider_token_not_init(self, app):
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1293,7 +1294,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(ProviderNotInitializeError):
method(api)
def test_post_generic_exception(self, app):
def test_post_generic_exception(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1330,7 +1331,7 @@ class TestDatasetIndexingEstimateApi:
class TestDatasetRelatedAppListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1368,7 +1369,7 @@ class TestDatasetRelatedAppListApi:
assert response["total"] == 2
assert response["data"] == [app1, app2]
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1386,7 +1387,7 @@ class TestDatasetRelatedAppListApi:
with pytest.raises(NotFound):
method(api, "dataset-1")
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1410,7 +1411,7 @@ class TestDatasetRelatedAppListApi:
with pytest.raises(Forbidden):
method(api, "dataset-1")
def test_get_filters_none_apps(self, app):
def test_get_filters_none_apps(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1449,7 +1450,7 @@ class TestDatasetRelatedAppListApi:
class TestDatasetIndexingStatusApi:
def test_get_success_with_documents(self, app):
def test_get_success_with_documents(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1490,7 +1491,7 @@ class TestDatasetIndexingStatusApi:
assert item["completed_segments"] == 3
assert item["total_segments"] == 3
def test_get_success_no_documents(self, app):
def test_get_success_no_documents(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1510,7 +1511,7 @@ class TestDatasetIndexingStatusApi:
assert status == 200
assert response == {"data": []}
def test_segment_counts_different_values(self, app):
def test_segment_counts_different_values(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1550,7 +1551,7 @@ class TestDatasetIndexingStatusApi:
class TestDatasetApiKeyApi:
def test_get_api_keys_success(self, app):
def test_get_api_keys_success(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.get)
@ -1587,7 +1588,7 @@ class TestDatasetApiKeyApi:
assert response["data"][1]["id"] == "key-2"
assert response["data"][1]["token"] == "ds-def"
def test_post_create_api_key_success(self, app):
def test_post_create_api_key_success(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.post)
@ -1632,7 +1633,7 @@ class TestDatasetApiKeyApi:
assert response["type"] == "dataset"
assert response["created_at"] is not None
def test_post_exceed_max_keys(self, app):
def test_post_exceed_max_keys(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.post)
@ -1658,7 +1659,7 @@ class TestDatasetApiKeyApi:
class TestDatasetApiDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasetApiDeleteApi()
method = unwrap(api.delete)
@ -1688,7 +1689,7 @@ class TestDatasetApiDeleteApi:
assert status == 204
assert response["result"] == "success"
def test_delete_key_not_found(self, app):
def test_delete_key_not_found(self, app: Flask):
api = DatasetApiDeleteApi()
method = unwrap(api.delete)
@ -1708,7 +1709,7 @@ class TestDatasetApiDeleteApi:
class TestDatasetEnableApiApi:
def test_enable_api(self, app):
def test_enable_api(self, app: Flask):
api = DatasetEnableApiApi()
method = unwrap(api.post)
@ -1724,7 +1725,7 @@ class TestDatasetEnableApiApi:
assert status == 200
assert response["result"] == "success"
def test_disable_api(self, app):
def test_disable_api(self, app: Flask):
api = DatasetEnableApiApi()
method = unwrap(api.post)
@ -1742,7 +1743,7 @@ class TestDatasetEnableApiApi:
class TestDatasetApiBaseUrlApi:
def test_get_api_base_url_from_config(self, app):
def test_get_api_base_url_from_config(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1757,7 +1758,7 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "https://example.com/v1"
def test_get_api_base_url_from_request(self, app):
def test_get_api_base_url_from_request(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1772,7 +1773,7 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "http://localhost:5000/v1"
def test_get_api_base_url_no_double_v1(self, app):
def test_get_api_base_url_no_double_v1(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1789,7 +1790,7 @@ class TestDatasetApiBaseUrlApi:
class TestDatasetRetrievalSettingApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRetrievalSettingApi()
method = unwrap(api.get)
@ -1810,7 +1811,7 @@ class TestDatasetRetrievalSettingApi:
class TestDatasetRetrievalSettingMockApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRetrievalSettingMockApi()
method = unwrap(api.get)
@ -1827,7 +1828,7 @@ class TestDatasetRetrievalSettingMockApi:
class TestDatasetErrorDocs:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetErrorDocs()
method = unwrap(api.get)
@ -1850,7 +1851,7 @@ class TestDatasetErrorDocs:
assert status == 200
assert response["total"] == 1
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetErrorDocs()
method = unwrap(api.get)
@ -1866,7 +1867,7 @@ class TestDatasetErrorDocs:
class TestDatasetPermissionUserListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetPermissionUserListApi()
method = unwrap(api.get)
@ -1897,7 +1898,7 @@ class TestDatasetPermissionUserListApi:
assert status == 200
assert response["data"] == users
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetPermissionUserListApi()
method = unwrap(api.get)
@ -1923,7 +1924,7 @@ class TestDatasetPermissionUserListApi:
class TestDatasetAutoDisableLogApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetAutoDisableLogApi()
method = unwrap(api.get)
@ -1946,7 +1947,7 @@ class TestDatasetAutoDisableLogApi:
assert status == 200
assert response == logs
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetAutoDisableLogApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -239,7 +240,7 @@ class TestDatasetDocumentListApi:
assert "documents" in response
def test_post_forbidden(self, app):
def test_post_forbidden(self, app: Flask):
api = DatasetDocumentListApi()
method = unwrap(api.post)
@ -395,7 +396,7 @@ class TestDocumentDownloadApi:
class TestDocumentProcessingApi:
def test_processing_forbidden_when_not_editor(self, app):
def test_processing_forbidden_when_not_editor(self, app: Flask):
api = DocumentProcessingApi()
method = unwrap(api.patch)
@ -1185,7 +1186,7 @@ class TestDocumentPermissionCases:
"preview": [],
}
def test_document_tenant_mismatch(self, app):
def test_document_tenant_mismatch(self, app: Flask):
api = DocumentApi()
method = unwrap(api.get)
@ -1253,7 +1254,7 @@ class TestDocumentPermissionCases:
assert status == 200
assert response["mode"] == "custom"
def test_process_rule_permission_denied(self, app):
def test_process_rule_permission_denied(self, app: Flask):
api = GetProcessRuleApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -82,7 +83,7 @@ def test_get_segment_with_summary(monkeypatch):
class TestDatasetDocumentSegmentListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -132,7 +133,7 @@ class TestDatasetDocumentSegmentListApi:
assert status == 200
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -150,7 +151,7 @@ class TestDatasetDocumentSegmentListApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -176,7 +177,7 @@ class TestDatasetDocumentSegmentListApi:
class TestDatasetDocumentSegmentApi:
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -221,7 +222,7 @@ class TestDatasetDocumentSegmentApi:
assert status == 200
assert response["result"] == "success"
def test_patch_document_indexing_in_progress(self, app):
def test_patch_document_indexing_in_progress(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -264,7 +265,7 @@ class TestDatasetDocumentSegmentApi:
with pytest.raises(InvalidActionError):
method(api, "ds-1", "doc-1", "disable")
def test_patch_llm_bad_request(self, app):
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -308,7 +309,7 @@ class TestDatasetDocumentSegmentApi:
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1", "enable")
def test_patch_provider_token_not_init(self, app):
def test_patch_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -354,7 +355,7 @@ class TestDatasetDocumentSegmentApi:
class TestDatasetDocumentSegmentAddApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -413,7 +414,7 @@ class TestDatasetDocumentSegmentAddApi:
assert status == 200
assert response["data"]["id"] == "seg-1"
def test_post_llm_bad_request(self, app):
def test_post_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -452,7 +453,7 @@ class TestDatasetDocumentSegmentAddApi:
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1")
def test_post_provider_token_not_init(self, app):
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -493,7 +494,7 @@ class TestDatasetDocumentSegmentAddApi:
class TestDatasetDocumentSegmentUpdateApi:
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
@ -551,7 +552,7 @@ class TestDatasetDocumentSegmentUpdateApi:
assert status == 200
assert "data" in response
def test_patch_llm_bad_request(self, app):
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
@ -596,7 +597,7 @@ class TestDatasetDocumentSegmentUpdateApi:
class TestDatasetDocumentSegmentBatchImportApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -638,7 +639,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
assert status == 200
assert response["job_status"] == "waiting"
def test_post_dataset_not_found(self, app):
def test_post_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -659,7 +660,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_document_not_found(self, app):
def test_post_document_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -684,7 +685,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_upload_file_not_found(self, app):
def test_post_upload_file_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -713,7 +714,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_invalid_file_type(self, app):
def test_post_invalid_file_type(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -745,7 +746,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(ValueError):
method(api, "ds-1", "doc-1")
def test_post_async_task_failure(self, app):
def test_post_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -783,7 +784,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
assert status == 500
assert "error" in response
def test_get_job_not_found_in_redis(self, app):
def test_get_job_not_found_in_redis(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)
@ -799,7 +800,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
class TestChildChunkAddApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
@ -852,7 +853,7 @@ class TestChildChunkAddApi:
assert status == 200
assert response["data"]["id"] == "cc-1"
def test_post_child_chunk_indexing_error(self, app):
def test_post_child_chunk_indexing_error(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
@ -897,7 +898,7 @@ class TestChildChunkAddApi:
class TestChildChunkUpdateApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
@ -941,7 +942,7 @@ class TestChildChunkUpdateApi:
assert status == 204
assert response["result"] == "success"
def test_delete_child_chunk_index_error(self, app):
def test_delete_child_chunk_index_error(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
@ -984,7 +985,7 @@ class TestChildChunkUpdateApi:
class TestSegmentListAdvancedCases:
def test_segment_list_with_keyword_filter(self, app):
def test_segment_list_with_keyword_filter(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1035,7 +1036,7 @@ class TestSegmentListAdvancedCases:
assert status == 200
assert response["total"] == 1
def test_segment_list_permission_denied(self, app):
def test_segment_list_permission_denied(self, app: Flask):
"""Test segment list with permission denied"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1058,7 +1059,7 @@ class TestSegmentListAdvancedCases:
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1")
def test_segment_list_dataset_not_found(self, app):
def test_segment_list_dataset_not_found(self, app: Flask):
"""Test segment list with dataset not found"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1079,7 +1080,7 @@ class TestSegmentListAdvancedCases:
class TestSegmentOperationCases:
def test_segment_add_with_provider_token_error(self, app):
def test_segment_add_with_provider_token_error(self, app: Flask):
"""Test segment add with provider token not initialized"""
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -1117,7 +1118,7 @@ class TestSegmentOperationCases:
with pytest.raises(ProviderTokenNotInitError):
method(api, "ds-1", "doc-1")
def test_batch_import_with_document_not_found(self, app):
def test_batch_import_with_document_not_found(self, app: Flask):
"""Test batch import with document not found"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1146,7 +1147,7 @@ class TestSegmentOperationCases:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_batch_import_with_invalid_file(self, app):
def test_batch_import_with_invalid_file(self, app: Flask):
"""Test batch import with invalid file type"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1181,7 +1182,7 @@ class TestSegmentOperationCases:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_batch_import_with_async_task_failure(self, app):
def test_batch_import_with_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1226,7 +1227,7 @@ class TestSegmentOperationCases:
assert status == 500
assert "error" in response
def test_batch_import_get_job_not_found(self, app):
def test_batch_import_get_job_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)

View File

@ -57,7 +57,7 @@ def mock_auth(monkeypatch, current_user):
class TestExternalApiTemplateListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
@ -78,7 +78,7 @@ class TestExternalApiTemplateListApi:
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
def test_post_forbidden(self, app, current_user):
def test_post_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
@ -93,7 +93,7 @@ class TestExternalApiTemplateListApi:
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
def test_post_duplicate_name(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
@ -114,7 +114,7 @@ class TestExternalApiTemplateListApi:
class TestExternalApiTemplateApi:
def test_get_not_found(self, app):
def test_get_not_found(self, app: Flask):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
@ -129,7 +129,7 @@ class TestExternalApiTemplateApi:
with pytest.raises(NotFound):
method(api, "api-id")
def test_delete_forbidden(self, app, current_user):
def test_delete_forbidden(self, app: Flask, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
@ -142,7 +142,7 @@ class TestExternalApiTemplateApi:
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app):
def test_get_scopes_usage_check_to_current_tenant(self, app: Flask):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
@ -162,7 +162,7 @@ class TestExternalApiUseCheckApi:
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
@ -206,7 +206,7 @@ class TestExternalDatasetCreateApi:
assert status == 201
def test_create_forbidden(self, app, current_user):
def test_create_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
@ -226,7 +226,7 @@ class TestExternalDatasetCreateApi:
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app):
def test_hit_testing_dataset_not_found(self, app: Flask):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
@ -241,7 +241,7 @@ class TestExternalKnowledgeHitTestingApi:
with pytest.raises(NotFound):
method(api, "dataset-id")
def test_hit_testing_success(self, app):
def test_hit_testing_success(self, app: Flask):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
@ -266,7 +266,7 @@ class TestExternalKnowledgeHitTestingApi:
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app):
def test_bedrock_retrieval(self, app: Flask):
api = BedrockRetrievalApi()
method = unwrap(api.post)

View File

@ -269,7 +269,7 @@ class TestDatasetMetadataApi:
class TestDatasetMetadataBuiltInFieldApi:
def test_get_built_in_fields(self, app):
def test_get_built_in_fields(self, app: Flask):
api = DatasetMetadataBuiltInFieldApi()
method = unwrap(api.get)

View File

@ -1,6 +1,8 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from flask import Flask
import controllers.console.explore.banner as banner_module
from models.enums import BannerStatus
@ -12,7 +14,7 @@ def unwrap(func):
class TestBannerApi:
def test_get_banners_with_requested_language(self, app):
def test_get_banners_with_requested_language(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)
@ -41,7 +43,7 @@ class TestBannerApi:
}
]
def test_get_banners_fallback_to_en_us(self, app):
def test_get_banners_fallback_to_en_us(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)
@ -76,7 +78,7 @@ class TestBannerApi:
}
]
def test_get_banners_default_language_en_us(self, app):
def test_get_banners_default_language_en_us(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import InternalServerError, NotFound
import controllers.console.explore.message as module
@ -54,7 +55,7 @@ def make_message():
class TestMessageListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -96,7 +97,7 @@ class TestMessageListApi:
with pytest.raises(NotChatAppError):
method(installed_app)
def test_conversation_not_exists(self, app):
def test_conversation_not_exists(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -118,7 +119,7 @@ class TestMessageListApi:
with pytest.raises(NotFound):
method(installed_app)
def test_first_message_not_exists(self, app):
def test_first_message_not_exists(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -142,7 +143,7 @@ class TestMessageListApi:
class TestMessageFeedbackApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
@ -161,7 +162,7 @@ class TestMessageFeedbackApi:
assert result["result"] == "success"
def test_message_not_exists(self, app):
def test_message_not_exists(self, app: Flask):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
@ -182,7 +183,7 @@ class TestMessageFeedbackApi:
class TestMessageMoreLikeThisApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -221,7 +222,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(NotCompletionAppError):
method(installed_app, "mid")
def test_more_like_this_disabled(self, app):
def test_more_like_this_disabled(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -243,7 +244,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(AppMoreLikeThisDisabledError):
method(installed_app, "mid")
def test_message_not_exists_more_like_this(self, app):
def test_message_not_exists_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -265,7 +266,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_more_like_this(self, app):
def test_provider_not_init_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -287,7 +288,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_more_like_this(self, app):
def test_quota_exceeded_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -309,7 +310,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_more_like_this(self, app):
def test_model_not_support_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -331,7 +332,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_more_like_this(self, app):
def test_invoke_error_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -353,7 +354,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_more_like_this(self, app):
def test_unexpected_error_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)

View File

@ -1,5 +1,7 @@
from unittest.mock import MagicMock, patch
from flask import Flask
import controllers.console.explore.recommended_app as module
from models.model import AppMode, IconType
@ -11,7 +13,7 @@ def unwrap(func):
class TestRecommendedAppListApi:
def test_get_with_language_param(self, app):
def test_get_with_language_param(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -31,7 +33,7 @@ class TestRecommendedAppListApi:
service_mock.assert_called_once_with("en-US")
assert result == result_data
def test_get_fallback_to_user_language(self, app):
def test_get_fallback_to_user_language(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -51,7 +53,7 @@ class TestRecommendedAppListApi:
service_mock.assert_called_once_with("fr-FR")
assert result == result_data
def test_get_fallback_to_default_language(self, app):
def test_get_fallback_to_default_language(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -73,7 +75,7 @@ class TestRecommendedAppListApi:
class TestRecommendedAppApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.RecommendedAppApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.saved_message as module
@ -42,7 +43,7 @@ def payload_patch():
class TestSavedMessageListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.SavedMessageListApi()
method = unwrap(api.get)

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import controllers.console.explore.trial as module
@ -88,7 +89,7 @@ def valid_parameters():
class TestTrialAppWorkflowRunApi:
def test_not_workflow_app(self, app):
def test_not_workflow_app(self, app: Flask):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -224,7 +225,7 @@ class TestTrialAppWorkflowRunApi:
class TestTrialChatApi:
def test_not_chat_app(self, app):
def test_not_chat_app(self, app: Flask):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -408,7 +409,7 @@ class TestTrialChatApi:
class TestTrialCompletionApi:
def test_not_completion_app(self, app):
def test_not_completion_app(self, app: Flask):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -560,7 +561,7 @@ class TestTrialCompletionApi:
class TestTrialMessageSuggestedQuestionApi:
def test_not_chat_app(self, app):
def test_not_chat_app(self, app: Flask):
api = module.TrialMessageSuggestedQuestionApi()
method = unwrap(api.get)
@ -952,7 +953,7 @@ class TestTrialAppWorkflowTaskStopApi:
class TestTrialSitApi:
def test_no_site(self, app):
def test_no_site(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
app_model = MagicMock()
@ -963,7 +964,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_archived_tenant(self, app):
def test_archived_tenant(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
@ -978,7 +979,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_success(self, app):
def test_success(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)

View File

@ -73,7 +73,7 @@ def payload_patch():
class TestTagListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = TagListApi()
method = unwrap(api.get)
@ -124,7 +124,7 @@ class TestTagListApi:
assert result["name"] == "test-tag"
assert result["binding_count"] == "0"
def test_post_forbidden(self, app, readonly_user, payload_patch):
def test_post_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagListApi()
method = unwrap(api.post)
@ -170,7 +170,7 @@ class TestTagUpdateDeleteApi:
assert status == 200
assert result["binding_count"] == "3"
def test_patch_forbidden(self, app, readonly_user, payload_patch):
def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
@ -231,7 +231,7 @@ class TestTagBindingCollectionApi:
assert status == 200
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
def test_create_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagBindingCollectionApi()
method = unwrap(api.post)
@ -275,7 +275,7 @@ class TestTagBindingRemoveApi:
assert status == 200
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
def test_remove_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagBindingRemoveApi()
method = unwrap(api.post)

View File

@ -82,7 +82,7 @@ def mock_file_service(mock_db):
class TestFileApiGet:
def test_get_upload_config(self, app):
def test_get_upload_config(self, app: Flask):
api = FileApi()
get_method = unwrap(api.get)
@ -290,7 +290,7 @@ class TestFilePreviewApi:
class TestFileSupportTypeApi:
def test_get_supported_types(self, app):
def test_get_supported_types(self, app: Flask):
api = FileSupportTypeApi()
get_method = unwrap(api.get)

View File

@ -58,7 +58,7 @@ class TestChangeEmailSend:
mock_get_change_data,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
@ -107,7 +107,7 @@ class TestChangeEmailSend:
mock_get_change_data,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
@ -155,7 +155,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
@ -214,7 +214,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
@ -267,7 +267,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
@ -316,7 +316,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
@ -366,7 +366,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
@ -418,7 +418,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
@ -471,7 +471,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
@ -547,7 +547,7 @@ class TestAccountServiceSendChangeEmailEmail:
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask):
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@ -563,7 +563,7 @@ class TestCheckEmailUnique:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask):
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@ -43,7 +43,7 @@ class TestMemberInviteEmailApi:
mock_current_account,
mock_invite_member,
mock_get_features,
app,
app: Flask,
):
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -41,7 +42,7 @@ def unwrap(func):
class TestAccountInitApi:
def test_init_success(self, app):
def test_init_success(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
@ -64,7 +65,7 @@ class TestAccountInitApi:
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
def test_init_already_initialized(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
@ -79,7 +80,7 @@ class TestAccountInitApi:
class TestAccountProfileApi:
def test_get_profile_success(self, app):
def test_get_profile_success(self, app: Flask):
api = AccountProfileApi()
method = unwrap(api.get)
@ -140,7 +141,7 @@ class TestAccountUpdateApis:
class TestAccountAvatarApiGet:
"""GET /account/avatar must not sign arbitrary upload_file IDs (IDOR)."""
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app):
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -172,7 +173,7 @@ class TestAccountAvatarApiGet:
assert result == {"avatar_url": "https://signed/example"}
sign_mock.assert_called_once_with(upload_file_id=file_id)
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app):
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -204,7 +205,7 @@ class TestAccountAvatarApiGet:
sign_mock.assert_not_called()
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app):
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -236,7 +237,7 @@ class TestAccountAvatarApiGet:
sign_mock.assert_not_called()
def test_get_avatar_https_pass_through_without_signing(self, app):
def test_get_avatar_https_pass_through_without_signing(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -263,7 +264,7 @@ class TestAccountAvatarApiGet:
class TestAccountPasswordApi:
def test_password_success(self, app):
def test_password_success(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
@ -292,7 +293,7 @@ class TestAccountPasswordApi:
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
def test_password_wrong_current(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
@ -317,7 +318,7 @@ class TestAccountPasswordApi:
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
def test_get_integrates(self, app: Flask):
api = AccountIntegrateApi()
method = unwrap(api.get)
@ -336,7 +337,7 @@ class TestAccountIntegrateApi:
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
def test_delete_verify_success(self, app: Flask):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
@ -358,7 +359,7 @@ class TestAccountDeleteApi:
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
def test_delete_invalid_code(self, app: Flask):
api = AccountDeleteApi()
method = unwrap(api.post)
@ -379,7 +380,7 @@ class TestAccountDeleteApi:
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
def test_check_email_code_invalid(self, app: Flask):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
@ -405,7 +406,7 @@ class TestChangeEmailApis:
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
def test_reset_email_already_used(self, app: Flask):
api = ChangeEmailResetApi()
method = unwrap(api.post)
@ -427,7 +428,7 @@ class TestChangeEmailApis:
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
def test_email_unique_success(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)
@ -448,7 +449,7 @@ class TestCheckEmailUniqueApi:
assert result["result"] == "success"
def test_email_in_freeze(self, app):
def test_email_in_freeze(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
@ -16,7 +17,7 @@ def unwrap(func):
class TestAgentProviderListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -39,7 +40,7 @@ class TestAgentProviderListApi:
assert result == providers
def test_get_empty_list(self, app):
def test_get_empty_list(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -61,7 +62,7 @@ class TestAgentProviderListApi:
assert result == []
def test_get_account_not_found(self, app):
def test_get_account_not_found(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -77,7 +78,7 @@ class TestAgentProviderListApi:
class TestAgentProviderApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
@ -101,7 +102,7 @@ class TestAgentProviderApi:
assert result == provider_data
def test_get_provider_not_found(self, app):
def test_get_provider_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
@ -124,7 +125,7 @@ class TestAgentProviderApi:
assert result is None
def test_get_account_not_found(self, app):
def test_get_account_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.workspace.endpoint import (
@ -39,7 +40,7 @@ def patch_current_account(user_and_tenant):
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCollectionApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -57,7 +58,7 @@ class TestEndpointCollectionApi:
assert result["success"] is True
def test_create_permission_denied(self, app):
def test_create_permission_denied(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -77,7 +78,7 @@ class TestEndpointCollectionApi:
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
def test_create_validation_error(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -96,7 +97,7 @@ class TestEndpointCollectionApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointCreateApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = DeprecatedEndpointCreateApi()
method = unwrap(api.post)
@ -117,7 +118,7 @@ class TestDeprecatedEndpointCreateApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
@ -130,7 +131,7 @@ class TestEndpointListApi:
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
def test_list_invalid_query(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
@ -143,7 +144,7 @@ class TestEndpointListApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
def test_list_for_plugin_success(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
@ -158,7 +159,7 @@ class TestEndpointListForSinglePluginApi:
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
def test_list_for_plugin_missing_param(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
@ -171,7 +172,7 @@ class TestEndpointListForSinglePluginApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointItemApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
@ -187,7 +188,7 @@ class TestEndpointItemApi:
assert result["success"] is True
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
def test_delete_service_failure(self, app):
def test_delete_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
@ -199,7 +200,7 @@ class TestEndpointItemApi:
assert result["success"] is False
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -226,7 +227,7 @@ class TestEndpointItemApi:
settings={"x": 1},
)
def test_update_validation_error(self, app):
def test_update_validation_error(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -238,7 +239,7 @@ class TestEndpointItemApi:
with pytest.raises(ValueError):
method(api, "e1")
def test_update_service_failure(self, app):
def test_update_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -258,7 +259,7 @@ class TestEndpointItemApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -272,7 +273,7 @@ class TestDeprecatedEndpointDeleteApi:
assert result["success"] is True
def test_delete_invalid_payload(self, app):
def test_delete_invalid_payload(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -282,7 +283,7 @@ class TestDeprecatedEndpointDeleteApi:
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
def test_delete_service_failure(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -299,7 +300,7 @@ class TestDeprecatedEndpointDeleteApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -317,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi:
assert result["success"] is True
def test_update_validation_error(self, app):
def test_update_validation_error(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -329,7 +330,7 @@ class TestDeprecatedEndpointUpdateApi:
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
def test_update_service_failure(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -380,7 +381,7 @@ class TestEndpointRouteMetadata:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
def test_enable_success(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -394,7 +395,7 @@ class TestEndpointEnableApi:
assert result["success"] is True
def test_enable_invalid_payload(self, app):
def test_enable_invalid_payload(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -404,7 +405,7 @@ class TestEndpointEnableApi:
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
def test_enable_service_failure(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -421,7 +422,7 @@ class TestEndpointEnableApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
def test_disable_success(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)
@ -435,7 +436,7 @@ class TestEndpointDisableApi:
assert result["success"] is True
def test_disable_invalid_payload(self, app):
def test_disable_invalid_payload(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
import services
@ -34,7 +35,7 @@ def unwrap(func):
class TestMemberListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = MemberListApi()
method = unwrap(api.get)
@ -59,7 +60,7 @@ class TestMemberListApi:
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
def test_get_no_tenant(self, app: Flask):
api = MemberListApi()
method = unwrap(api.get)
@ -74,7 +75,7 @@ class TestMemberListApi:
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
def test_invite_success(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -101,7 +102,7 @@ class TestMemberInviteEmailApi:
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
def test_invite_limit_exceeded(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -123,7 +124,7 @@ class TestMemberInviteEmailApi:
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
def test_invite_already_member(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -151,7 +152,7 @@ class TestMemberInviteEmailApi:
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
def test_invite_invalid_role(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -166,7 +167,7 @@ class TestMemberInviteEmailApi:
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
def test_invite_generic_exception(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -196,7 +197,7 @@ class TestMemberInviteEmailApi:
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
def test_cancel_success(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -216,7 +217,7 @@ class TestMemberCancelInviteApi:
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
def test_cancel_not_found(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -233,7 +234,7 @@ class TestMemberCancelInviteApi:
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
def test_cancel_cannot_operate_self(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -255,7 +256,7 @@ class TestMemberCancelInviteApi:
assert status == 400
def test_cancel_no_permission(self, app):
def test_cancel_no_permission(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -277,7 +278,7 @@ class TestMemberCancelInviteApi:
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
def test_cancel_member_not_in_tenant(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -301,7 +302,7 @@ class TestMemberCancelInviteApi:
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -324,7 +325,7 @@ class TestMemberUpdateRoleApi:
assert result["result"] == "success"
def test_update_invalid_role(self, app):
def test_update_invalid_role(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -335,7 +336,7 @@ class TestMemberUpdateRoleApi:
assert status == 400
def test_update_member_not_found(self, app):
def test_update_member_not_found(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -354,7 +355,7 @@ class TestMemberUpdateRoleApi:
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
@ -381,7 +382,7 @@ class TestDatasetOperatorMemberListApi:
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
def test_get_no_tenant(self, app: Flask):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
@ -396,7 +397,7 @@ class TestDatasetOperatorMemberListApi:
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
def test_send_success(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -419,7 +420,7 @@ class TestSendOwnerTransferEmailApi:
assert result["result"] == "success"
def test_send_ip_limit(self, app):
def test_send_ip_limit(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -433,7 +434,7 @@ class TestSendOwnerTransferEmailApi:
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
def test_send_not_owner(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -452,7 +453,7 @@ class TestSendOwnerTransferEmailApi:
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
def test_check_invalid_code(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -477,7 +478,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
def test_rate_limited(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -498,7 +499,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
def test_invalid_token(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -520,7 +521,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
def test_invalid_email(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -547,7 +548,7 @@ class TestOwnerTransferCheckApi:
class TestOwnerTransferApi:
def test_transfer_self(self, app):
def test_transfer_self(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
@ -564,7 +565,7 @@ class TestOwnerTransferApi:
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
def test_invalid_token(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
@ -582,7 +583,7 @@ class TestOwnerTransferApi:
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
def test_member_not_in_tenant(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic_core import ValidationError
from werkzeug.exceptions import Forbidden
@ -26,7 +27,7 @@ def unwrap(func):
class TestModelProviderListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ModelProviderListApi()
method = unwrap(api.get)
@ -47,7 +48,7 @@ class TestModelProviderListApi:
class TestModelProviderCredentialApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
@ -66,7 +67,7 @@ class TestModelProviderCredentialApi:
assert "credentials" in result
def test_get_invalid_uuid(self, app):
def test_get_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
@ -80,7 +81,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_post_create_success(self, app):
def test_post_create_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
@ -102,7 +103,7 @@ class TestModelProviderCredentialApi:
assert result["result"] == "success"
assert status == 201
def test_post_create_validation_error(self, app):
def test_post_create_validation_error(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
@ -122,7 +123,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValueError):
method(api, provider="openai")
def test_put_update_success(self, app):
def test_put_update_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
@ -143,7 +144,7 @@ class TestModelProviderCredentialApi:
assert result["result"] == "success"
def test_put_invalid_uuid(self, app):
def test_put_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
@ -159,7 +160,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.delete)
@ -183,7 +184,7 @@ class TestModelProviderCredentialApi:
class TestModelProviderCredentialSwitchApi:
def test_switch_success(self, app):
def test_switch_success(self, app: Flask):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
@ -204,7 +205,7 @@ class TestModelProviderCredentialSwitchApi:
assert result["result"] == "success"
def test_switch_invalid_uuid(self, app):
def test_switch_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
@ -222,7 +223,7 @@ class TestModelProviderCredentialSwitchApi:
class TestModelProviderValidateApi:
def test_validate_success(self, app):
def test_validate_success(self, app: Flask):
api = ModelProviderValidateApi()
method = unwrap(api.post)
@ -243,7 +244,7 @@ class TestModelProviderValidateApi:
assert result["result"] == "success"
def test_validate_failure(self, app):
def test_validate_failure(self, app: Flask):
api = ModelProviderValidateApi()
method = unwrap(api.post)
@ -266,7 +267,7 @@ class TestModelProviderValidateApi:
class TestModelProviderIconApi:
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = ModelProviderIconApi()
with (
@ -280,7 +281,7 @@ class TestModelProviderIconApi:
assert response.mimetype == "image/png"
def test_icon_not_found(self, app):
def test_icon_not_found(self, app: Flask):
api = ModelProviderIconApi()
with (
@ -295,7 +296,7 @@ class TestModelProviderIconApi:
class TestPreferredProviderTypeUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
@ -316,7 +317,7 @@ class TestPreferredProviderTypeUpdateApi:
assert result["result"] == "success"
def test_invalid_enum(self, app):
def test_invalid_enum(self, app: Flask):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
@ -334,7 +335,7 @@ class TestPreferredProviderTypeUpdateApi:
class TestModelProviderPaymentCheckoutUrlApi:
def test_checkout_success(self, app):
def test_checkout_success(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
@ -359,7 +360,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
assert "url" in result
def test_invalid_provider(self, app):
def test_invalid_provider(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
@ -367,7 +368,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
with pytest.raises(ValueError):
method(api, provider="openai")
def test_permission_denied(self, app):
def test_permission_denied(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)

View File

@ -72,7 +72,7 @@ class TestDefaultModelApi:
assert result["result"] == "success"
def test_get_returns_empty_when_no_default(self, app):
def test_get_returns_empty_when_no_default(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.get)
@ -154,7 +154,7 @@ class TestModelProviderModelApi:
assert status == 204
def test_get_models_returns_empty(self, app):
def test_get_models_returns_empty(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.get)
@ -224,7 +224,7 @@ class TestModelProviderModelCredentialApi:
assert status == 201
def test_get_empty_credentials(self, app):
def test_get_empty_credentials(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
@ -242,7 +242,7 @@ class TestModelProviderModelCredentialApi:
assert result["credentials"] == {}
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.delete)
@ -416,7 +416,7 @@ class TestParameterAndAvailableModels:
assert "data" in result
def test_empty_rules(self, app):
def test_empty_rules(self, app: Flask):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
@ -431,7 +431,7 @@ class TestParameterAndAvailableModels:
assert result["data"] == []
def test_no_models(self, app):
def test_no_models(self, app: Flask):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ import io
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
@ -61,7 +62,7 @@ def tenant():
class TestPluginListLatestVersionsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginListLatestVersionsApi()
method = unwrap(api.post)
@ -77,7 +78,7 @@ class TestPluginListLatestVersionsApi:
assert "versions" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginListLatestVersionsApi()
method = unwrap(api.post)
@ -95,7 +96,7 @@ class TestPluginListLatestVersionsApi:
class TestPluginDebuggingKeyApi:
def test_debugging_key_success(self, app):
def test_debugging_key_success(self, app: Flask):
api = PluginDebuggingKeyApi()
method = unwrap(api.get)
@ -108,7 +109,7 @@ class TestPluginDebuggingKeyApi:
assert result["key"] == "k"
def test_debugging_key_error(self, app):
def test_debugging_key_error(self, app: Flask):
api = PluginDebuggingKeyApi()
method = unwrap(api.get)
@ -125,7 +126,7 @@ class TestPluginDebuggingKeyApi:
class TestPluginListApi:
def test_plugin_list(self, app):
def test_plugin_list(self, app: Flask):
api = PluginListApi()
method = unwrap(api.get)
@ -142,7 +143,7 @@ class TestPluginListApi:
class TestPluginIconApi:
def test_plugin_icon(self, app):
def test_plugin_icon(self, app: Flask):
api = PluginIconApi()
method = unwrap(api.get)
@ -156,7 +157,7 @@ class TestPluginIconApi:
class TestPluginAssetApi:
def test_plugin_asset(self, app):
def test_plugin_asset(self, app: Flask):
api = PluginAssetApi()
method = unwrap(api.get)
@ -171,7 +172,7 @@ class TestPluginAssetApi:
class TestPluginUploadFromPkgApi:
def test_upload_pkg_success(self, app):
def test_upload_pkg_success(self, app: Flask):
api = PluginUploadFromPkgApi()
method = unwrap(api.post)
@ -188,7 +189,7 @@ class TestPluginUploadFromPkgApi:
assert result["ok"] is True
def test_upload_pkg_too_large(self, app):
def test_upload_pkg_too_large(self, app: Flask):
api = PluginUploadFromPkgApi()
method = unwrap(api.post)
@ -210,7 +211,7 @@ class TestPluginUploadFromPkgApi:
class TestPluginInstallFromPkgApi:
def test_install_from_pkg(self, app):
def test_install_from_pkg(self, app: Flask):
api = PluginInstallFromPkgApi()
method = unwrap(api.post)
@ -229,7 +230,7 @@ class TestPluginInstallFromPkgApi:
class TestPluginUninstallApi:
def test_uninstall(self, app):
def test_uninstall(self, app: Flask):
api = PluginUninstallApi()
method = unwrap(api.post)
@ -246,7 +247,7 @@ class TestPluginUninstallApi:
class TestPluginChangePermissionApi:
def test_change_permission_forbidden(self, app):
def test_change_permission_forbidden(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
@ -264,7 +265,7 @@ class TestPluginChangePermissionApi:
with pytest.raises(Forbidden):
method(api)
def test_change_permission_success(self, app):
def test_change_permission_success(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
@ -286,7 +287,7 @@ class TestPluginChangePermissionApi:
class TestPluginFetchPermissionApi:
def test_fetch_permission_default(self, app):
def test_fetch_permission_default(self, app: Flask):
api = PluginFetchPermissionApi()
method = unwrap(api.get)
@ -319,7 +320,7 @@ class TestPluginFetchDynamicSelectOptionsApi:
class TestPluginReadmeApi:
def test_fetch_readme(self, app):
def test_fetch_readme(self, app: Flask):
api = PluginReadmeApi()
method = unwrap(api.get)
@ -334,7 +335,7 @@ class TestPluginReadmeApi:
class TestPluginListInstallationsFromIdsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -352,7 +353,7 @@ class TestPluginListInstallationsFromIdsApi:
assert "plugins" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -371,7 +372,7 @@ class TestPluginListInstallationsFromIdsApi:
class TestPluginUploadFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -388,7 +389,7 @@ class TestPluginUploadFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -407,7 +408,7 @@ class TestPluginUploadFromGithubApi:
class TestPluginUploadFromBundleApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUploadFromBundleApi()
method = unwrap(api.post)
@ -430,7 +431,7 @@ class TestPluginUploadFromBundleApi:
assert result["ok"] is True
def test_too_large(self, app):
def test_too_large(self, app: Flask):
api = PluginUploadFromBundleApi()
method = unwrap(api.post)
@ -458,7 +459,7 @@ class TestPluginUploadFromBundleApi:
class TestPluginInstallFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginInstallFromGithubApi()
method = unwrap(api.post)
@ -478,7 +479,7 @@ class TestPluginInstallFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginInstallFromGithubApi()
method = unwrap(api.post)
@ -502,7 +503,7 @@ class TestPluginInstallFromGithubApi:
class TestPluginInstallFromMarketplaceApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginInstallFromMarketplaceApi()
method = unwrap(api.post)
@ -520,7 +521,7 @@ class TestPluginInstallFromMarketplaceApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginInstallFromMarketplaceApi()
method = unwrap(api.post)
@ -539,7 +540,7 @@ class TestPluginInstallFromMarketplaceApi:
class TestPluginFetchMarketplacePkgApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchMarketplacePkgApi()
method = unwrap(api.get)
@ -552,7 +553,7 @@ class TestPluginFetchMarketplacePkgApi:
assert "manifest" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchMarketplacePkgApi()
method = unwrap(api.get)
@ -569,7 +570,7 @@ class TestPluginFetchMarketplacePkgApi:
class TestPluginFetchManifestApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchManifestApi()
method = unwrap(api.get)
@ -585,7 +586,7 @@ class TestPluginFetchManifestApi:
assert "manifest" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchManifestApi()
method = unwrap(api.get)
@ -602,7 +603,7 @@ class TestPluginFetchManifestApi:
class TestPluginFetchInstallTasksApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchInstallTasksApi()
method = unwrap(api.get)
@ -615,7 +616,7 @@ class TestPluginFetchInstallTasksApi:
assert "tasks" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchInstallTasksApi()
method = unwrap(api.get)
@ -632,7 +633,7 @@ class TestPluginFetchInstallTasksApi:
class TestPluginFetchInstallTaskApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchInstallTaskApi()
method = unwrap(api.get)
@ -645,7 +646,7 @@ class TestPluginFetchInstallTaskApi:
assert "task" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchInstallTaskApi()
method = unwrap(api.get)
@ -662,7 +663,7 @@ class TestPluginFetchInstallTaskApi:
class TestPluginDeleteInstallTaskApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteInstallTaskApi()
method = unwrap(api.post)
@ -675,7 +676,7 @@ class TestPluginDeleteInstallTaskApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteInstallTaskApi()
method = unwrap(api.post)
@ -692,7 +693,7 @@ class TestPluginDeleteInstallTaskApi:
class TestPluginDeleteAllInstallTaskItemsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteAllInstallTaskItemsApi()
method = unwrap(api.post)
@ -707,7 +708,7 @@ class TestPluginDeleteAllInstallTaskItemsApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteAllInstallTaskItemsApi()
method = unwrap(api.post)
@ -724,7 +725,7 @@ class TestPluginDeleteAllInstallTaskItemsApi:
class TestPluginDeleteInstallTaskItemApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteInstallTaskItemApi()
method = unwrap(api.post)
@ -737,7 +738,7 @@ class TestPluginDeleteInstallTaskItemApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteInstallTaskItemApi()
method = unwrap(api.post)
@ -754,7 +755,7 @@ class TestPluginDeleteInstallTaskItemApi:
class TestPluginUpgradeFromMarketplaceApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -775,7 +776,7 @@ class TestPluginUpgradeFromMarketplaceApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -797,7 +798,7 @@ class TestPluginUpgradeFromMarketplaceApi:
class TestPluginUpgradeFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUpgradeFromGithubApi()
method = unwrap(api.post)
@ -821,7 +822,7 @@ class TestPluginUpgradeFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUpgradeFromGithubApi()
method = unwrap(api.post)
@ -846,7 +847,7 @@ class TestPluginUpgradeFromGithubApi:
class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
@ -873,7 +874,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
assert result["options"] == [1]
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
@ -901,7 +902,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
class TestPluginChangePreferencesApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginChangePreferencesApi()
method = unwrap(api.post)
@ -931,7 +932,7 @@ class TestPluginChangePreferencesApi:
assert result["success"] is True
def test_permission_fail(self, app):
def test_permission_fail(self, app: Flask):
api = PluginChangePreferencesApi()
method = unwrap(api.post)
@ -962,7 +963,7 @@ class TestPluginChangePreferencesApi:
class TestPluginFetchPreferencesApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchPreferencesApi()
method = unwrap(api.get)
@ -996,7 +997,7 @@ class TestPluginFetchPreferencesApi:
class TestPluginAutoUpgradeExcludePluginApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginAutoUpgradeExcludePluginApi()
method = unwrap(api.post)
@ -1011,7 +1012,7 @@ class TestPluginAutoUpgradeExcludePluginApi:
assert result["success"] is True
def test_fail(self, app):
def test_fail(self, app: Flask):
api = PluginAutoUpgradeExcludePluginApi()
method = unwrap(api.post)

View File

@ -2,6 +2,7 @@ from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Unauthorized
@ -37,7 +38,7 @@ def unwrap(func):
class TestTenantListApi:
def test_get_success_saas_path(self, app):
def test_get_success_saas_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -85,7 +86,7 @@ class TestTenantListApi:
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app: Flask):
"""Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.
billing.enabled is mocked False to prove the endpoint does not gate on it for this path
@ -140,7 +141,7 @@ class TestTenantListApi:
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_called_once_with("t2")
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask):
"""Test fallback to FeatureService when bulk billing returns empty result.
BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
@ -197,7 +198,7 @@ class TestTenantListApi:
assert get_features_mock.call_count == 2
logger_warning_mock.assert_called_once()
def test_get_billing_disabled_community_path(self, app):
def test_get_billing_disabled_community_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -236,7 +237,7 @@ class TestTenantListApi:
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
def test_get_enterprise_only_skips_feature_service(self, app):
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -276,7 +277,7 @@ class TestTenantListApi:
assert result["workspaces"][1]["current"] is True
get_features_mock.assert_not_called()
def test_get_enterprise_only_with_empty_tenants(self, app):
def test_get_enterprise_only_with_empty_tenants(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -302,7 +303,7 @@ class TestTenantListApi:
class TestWorkspaceListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
@ -324,7 +325,7 @@ class TestWorkspaceListApi:
assert result["total"] == 1
assert result["has_more"] is False
def test_get_has_next_true(self, app):
def test_get_has_next_true(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
@ -355,7 +356,7 @@ class TestWorkspaceListApi:
class TestTenantApi:
def test_post_active_tenant(self, app):
def test_post_active_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -375,7 +376,7 @@ class TestTenantApi:
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app):
def test_post_archived_with_switch(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -397,7 +398,7 @@ class TestTenantApi:
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app):
def test_post_archived_no_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -411,7 +412,7 @@ class TestTenantApi:
with pytest.raises(Unauthorized):
method(api)
def test_post_info_path(self, app):
def test_post_info_path(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -454,7 +455,7 @@ class TestTenantInfoResponse:
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
def test_switch_success(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -477,7 +478,7 @@ class TestSwitchWorkspaceApi:
assert result["result"] == "success"
def test_switch_not_linked(self, app):
def test_switch_not_linked(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -493,7 +494,7 @@ class TestSwitchWorkspaceApi:
with pytest.raises(AccountNotLinkTenantError):
method(api)
def test_switch_tenant_not_found(self, app):
def test_switch_tenant_not_found(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -515,7 +516,7 @@ class TestSwitchWorkspaceApi:
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
@ -538,7 +539,7 @@ class TestCustomConfigWorkspaceApi:
assert result["result"] == "success"
def test_logo_fallback(self, app):
def test_logo_fallback(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
@ -569,7 +570,7 @@ class TestCustomConfigWorkspaceApi:
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app):
def test_no_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -582,7 +583,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(NoFileUploadedError):
method(api)
def test_too_many_files(self, app):
def test_too_many_files(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -601,7 +602,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(TooManyFilesError):
method(api)
def test_invalid_extension(self, app):
def test_invalid_extension(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -616,7 +617,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(UnsupportedFileTypeError):
method(api)
def test_upload_success(self, app):
def test_upload_success(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -648,7 +649,7 @@ class TestWebappLogoWorkspaceApi:
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app):
def test_filename_missing(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -672,7 +673,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(FilenameNotExistsError):
method(api)
def test_file_too_large(self, app):
def test_file_too_large(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -701,7 +702,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(FileTooLargeError):
method(api)
def test_service_unsupported_file(self, app):
def test_service_unsupported_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -732,7 +733,7 @@ class TestWebappLogoWorkspaceApi:
class TestWorkspaceInfoApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
@ -756,7 +757,7 @@ class TestWorkspaceInfoApi:
assert result["result"] == "success"
def test_no_current_tenant(self, app):
def test_no_current_tenant(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
@ -774,7 +775,7 @@ class TestWorkspaceInfoApi:
class TestWorkspacePermissionApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)
@ -799,7 +800,7 @@ class TestWorkspacePermissionApi:
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app):
def test_no_current_tenant(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)

View File

@ -41,7 +41,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_chat_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving parameters for a chat app."""
# Arrange
@ -91,7 +91,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_workflow_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving parameters for a workflow app."""
# Arrange
@ -136,7 +136,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_chat_config_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test that AppUnavailableError is raised when chat app has no config."""
# Arrange
@ -174,7 +174,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_workflow_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
# Arrange
@ -234,7 +234,14 @@ class TestAppMetaApi:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.app.app.AppService")
def test_get_app_meta(
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self,
mock_app_service,
mock_db,
mock_validate_token,
mock_current_app,
mock_user_logged_in,
app: Flask,
mock_app_model,
):
"""Test retrieving app metadata via AppService."""
# Arrange
@ -310,7 +317,7 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving basic app information."""
mock_current_app.login_manager = Mock()
@ -402,7 +409,9 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
def test_get_app_info_with_no_tags(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask
):
"""Test retrieving app info when app has no tags."""
# Arrange
mock_current_app.login_manager = Mock()
@ -453,7 +462,7 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_returns_correct_mode(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, app_mode
):
"""Test that all app modes are correctly returned."""
# Arrange

View File

@ -13,6 +13,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
@ -190,7 +191,7 @@ class TestAudioServiceMockedBehavior:
class TestAudioApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = AudioApi()
handler = _unwrap(api.post)
@ -216,7 +217,7 @@ class TestAudioApi:
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = AudioApi()
handler = _unwrap(api.post)
@ -227,7 +228,7 @@ class TestAudioApi:
with pytest.raises(expected):
handler(api, app_model=app_model, end_user=end_user)
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_unhandled_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
)
@ -242,7 +243,7 @@ class TestAudioApi:
class TestTextApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = TextApi()
@ -259,7 +260,7 @@ class TestTextApi:
assert response == {"audio": "ok"}
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
)

View File

@ -16,6 +16,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
@ -295,7 +296,7 @@ class TestCompletionControllerLogic:
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask):
"""Test CompletionApi.post success path."""
from controllers.service_api.app.completion import CompletionApi
@ -320,7 +321,7 @@ class TestCompletionControllerLogic:
mock_generate_service.generate.assert_called_once()
@patch("controllers.service_api.app.completion.service_api_ns")
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask):
"""Test CompletionApi.post with wrong app mode."""
from controllers.service_api.app.completion import CompletionApi
@ -334,7 +335,7 @@ class TestCompletionControllerLogic:
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask):
"""Test ChatApi.post success path."""
from controllers.service_api.app.completion import ChatApi
@ -355,7 +356,7 @@ class TestCompletionControllerLogic:
assert response == {"text": "compacted"}
@patch("controllers.service_api.app.completion.service_api_ns")
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask):
"""Test ChatApi.post with wrong app mode."""
from controllers.service_api.app.completion import ChatApi
@ -368,7 +369,7 @@ class TestCompletionControllerLogic:
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.AppTaskService")
def test_completion_stop_api_success(self, mock_task_service, app):
def test_completion_stop_api_success(self, mock_task_service, app: Flask):
"""Test CompletionStopApi.post success."""
from controllers.service_api.app.completion import CompletionStopApi
@ -385,7 +386,7 @@ class TestCompletionControllerLogic:
mock_task_service.stop_task.assert_called_once()
@patch("controllers.service_api.app.completion.AppTaskService")
def test_chat_stop_api_success(self, mock_task_service, app):
def test_chat_stop_api_success(self, mock_task_service, app: Flask):
"""Test ChatStopApi.post success."""
from controllers.service_api.app.completion import ChatStopApi

View File

@ -20,6 +20,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
import services
@ -504,7 +505,7 @@ class TestConversationApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_list_last_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _BeginStub:
def __enter__(self):
return SimpleNamespace()
@ -552,7 +553,7 @@ class TestConversationDetailApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"delete",
@ -570,7 +571,7 @@ class TestConversationDetailApiController:
class TestConversationRenameApiController:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"rename",
@ -602,7 +603,7 @@ class TestConversationVariablesApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
@ -621,7 +622,7 @@ class TestConversationVariablesApiController:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
@ -661,7 +662,7 @@ class TestConversationVariablesApiController:
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_type_mismatch(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
@ -687,7 +688,7 @@ class TestConversationVariableDetailApiController:
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
@ -713,7 +714,7 @@ class TestConversationVariableDetailApiController:
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,

View File

@ -16,6 +16,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from controllers.common.errors import (
FilenameNotExistsError,
@ -282,7 +283,7 @@ class TestFileApiPost:
assert status == 201
mock_file_svc_cls.return_value.upload_file.assert_called_once()
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
def test_upload_no_file(self, app: Flask, mock_app_model, mock_end_user):
"""Test NoFileUploadedError when no file in request."""
from controllers.service_api.app.file import FileApi
@ -296,7 +297,7 @@ class TestFileApiPost:
with pytest.raises(NoFileUploadedError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user):
"""Test TooManyFilesError when multiple files uploaded."""
from io import BytesIO
@ -317,7 +318,7 @@ class TestFileApiPost:
with pytest.raises(TooManyFilesError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user):
"""Test UnsupportedFileTypeError when file has no mimetype."""
from io import BytesIO

View File

@ -11,6 +11,7 @@ from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
from flask import Flask
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
@ -281,7 +282,7 @@ class TestHitlServiceApi:
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
self, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",

View File

@ -19,6 +19,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api.app.error import NotChatAppError
@ -390,7 +391,7 @@ class TestMessageListApi:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_conversation_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
@ -409,7 +410,7 @@ class TestMessageListApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_first_message_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
@ -430,7 +431,7 @@ class TestMessageListApi:
class TestMessageFeedbackApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"create_feedback",
@ -452,7 +453,7 @@ class TestMessageFeedbackApi:
class TestAppGetFeedbacksApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
api = AppGetFeedbacksApi()
@ -476,7 +477,7 @@ class TestMessageSuggestedApi:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -492,7 +493,7 @@ class TestMessageSuggestedApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_disabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -508,7 +509,7 @@ class TestMessageSuggestedApi:
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_internal_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -524,7 +525,7 @@ class TestMessageSuggestedApi:
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",

View File

@ -20,6 +20,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.service_api.app.error import NotWorkflowAppError
@ -366,7 +367,7 @@ class TestWorkflowRunRepository:
class TestWorkflowRunDetailApi:
def test_not_workflow_app(self, app) -> None:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -397,7 +398,7 @@ class TestWorkflowRunDetailApi:
class TestWorkflowRunApi:
def test_not_workflow_app(self, app) -> None:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -407,7 +408,7 @@ class TestWorkflowRunApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_rate_limit(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -425,7 +426,7 @@ class TestWorkflowRunApi:
class TestWorkflowRunByIdApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -441,7 +442,7 @@ class TestWorkflowRunByIdApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_draft_workflow(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -459,7 +460,7 @@ class TestWorkflowRunByIdApi:
class TestWorkflowTaskStopApi:
def test_wrong_mode(self, app) -> None:
def test_wrong_mode(self, app: Flask) -> None:
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -469,7 +470,7 @@ class TestWorkflowTaskStopApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
send_mock = Mock()
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
@ -489,7 +490,7 @@ class TestWorkflowTaskStopApi:
class TestWorkflowAppLogApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _BeginStub:
def __enter__(self):
return SimpleNamespace()
@ -557,7 +558,7 @@ class TestWorkflowRunDetailApiGet:
self,
mock_db,
mock_repo_factory,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow run detail retrieval."""
@ -579,7 +580,7 @@ class TestWorkflowRunDetailApiGet:
assert result["status"] == "succeeded"
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
def test_get_workflow_run_wrong_app_mode(self, mock_db, app: Flask):
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
from controllers.service_api.app.workflow import WorkflowRunDetailApi
@ -604,7 +605,7 @@ class TestWorkflowTaskStopApiPost:
self,
mock_queue_mgr,
mock_graph_mgr,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow task stop."""
@ -624,7 +625,7 @@ class TestWorkflowTaskStopApiPost:
mock_graph_mgr.assert_called_once()
mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1")
def test_stop_workflow_task_wrong_app_mode(self, app):
def test_stop_workflow_task_wrong_app_mode(self, app: Flask):
"""Test NotWorkflowAppError when app mode is not workflow."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
@ -649,7 +650,7 @@ class TestWorkflowAppLogApiGet:
self,
mock_db,
mock_wf_svc_cls,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow log retrieval."""

View File

@ -9,6 +9,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.app.error import NotWorkflowAppError
@ -41,7 +42,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
@ -52,7 +53,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_permission_denied(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -70,7 +71,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_finished_run_returns_sse(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -103,7 +104,7 @@ class TestWorkflowEventsApi:
assert payload["task_id"] == "run-1"
assert payload["event"] == "workflow_finished"
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_running_run_streams_events(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -135,7 +136,7 @@ class TestWorkflowEventsApi:
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_running_run_with_snapshot(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",

View File

@ -23,6 +23,7 @@ from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden, NotFound
@ -373,7 +374,7 @@ class TestDatasourcePluginsApiGet:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_success(self, mock_svc_cls, mock_db, app):
def test_get_plugins_success(self, mock_svc_cls, mock_db, app: Flask):
"""Test successful retrieval of datasource plugins."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
@ -396,7 +397,7 @@ class TestDatasourcePluginsApiGet:
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_get_plugins_not_found(self, mock_db, app):
def test_get_plugins_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -407,7 +408,7 @@ class TestDatasourcePluginsApiGet:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app):
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app: Flask):
"""Test empty plugin list."""
mock_db.session.scalar.return_value = Mock()
mock_svc_instance = Mock()
@ -439,7 +440,7 @@ class TestDatasourceNodeRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app):
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app: Flask):
"""Test successful datasource node run."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
@ -473,7 +474,7 @@ class TestDatasourceNodeRunApiPost:
mock_svc_instance.run_datasource_workflow_node.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
def test_post_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -488,7 +489,7 @@ class TestDatasourceNodeRunApiPost:
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app):
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app: Flask):
"""Test AssertionError when current_user is not an Account instance."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
@ -549,7 +550,7 @@ class TestPipelineRunApiPost:
mock_gen_svc.generate.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
def test_post_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -561,7 +562,7 @@ class TestPipelineRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app):
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app: Flask):
"""Test Forbidden when current_user is not an Account."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
@ -585,7 +586,7 @@ class TestFileUploadApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app):
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app: Flask):
"""Test successful file upload."""
mock_current_user.__bool__ = Mock(return_value=True)
@ -621,7 +622,7 @@ class TestFileUploadApiPost:
assert response["name"] == "doc.pdf"
assert response["extension"] == "pdf"
def test_upload_no_file(self, app):
def test_upload_no_file(self, app: Flask):
"""Test error when no file is uploaded."""
with app.test_request_context(
"/datasets/pipeline/file-upload",

View File

@ -18,6 +18,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.segment import (
@ -782,7 +783,7 @@ class TestSegmentApiGet:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -893,7 +894,7 @@ class TestSegmentApiPost:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -946,7 +947,7 @@ class TestSegmentApiPost:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -989,7 +990,7 @@ class TestSegmentApiPost:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1041,7 +1042,7 @@ class TestDatasetSegmentApiDelete:
mock_doc_svc,
mock_dataset_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1086,7 +1087,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_doc_svc,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1282,7 +1283,7 @@ class TestDatasetSegmentApiUpdate:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1322,7 +1323,7 @@ class TestDatasetSegmentApiUpdate:
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1374,7 +1375,7 @@ class TestDatasetSegmentApiGetSingle:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1421,7 +1422,7 @@ class TestDatasetSegmentApiGetSingle:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1460,7 +1461,7 @@ class TestDatasetSegmentApiGetSingle:
self,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1491,7 +1492,7 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1526,7 +1527,7 @@ class TestDatasetSegmentApiGetSingle:
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1570,7 +1571,7 @@ class TestChildChunkApiGet:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1609,7 +1610,7 @@ class TestChildChunkApiGet:
self,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1638,7 +1639,7 @@ class TestChildChunkApiGet:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1670,7 +1671,7 @@ class TestChildChunkApiGet:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1729,7 +1730,7 @@ class TestChildChunkApiPost:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1771,7 +1772,7 @@ class TestChildChunkApiPost:
mock_feature_svc,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1809,7 +1810,7 @@ class TestChildChunkApiPost:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1863,7 +1864,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1913,7 +1914,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1954,7 +1955,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1994,7 +1995,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):

View File

@ -19,6 +19,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.metadata import (
@ -76,7 +77,7 @@ class TestDatasetMetadataCreatePost:
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -106,7 +107,7 @@ class TestDatasetMetadataCreatePost:
def test_create_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -136,7 +137,7 @@ class TestDatasetMetadataCreateGet:
self,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -160,7 +161,7 @@ class TestDatasetMetadataCreateGet:
def test_get_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -201,7 +202,7 @@ class TestDatasetMetadataServiceApiPatch:
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -232,7 +233,7 @@ class TestDatasetMetadataServiceApiPatch:
def test_update_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -273,7 +274,7 @@ class TestDatasetMetadataServiceApiDelete:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -302,7 +303,7 @@ class TestDatasetMetadataServiceApiDelete:
def test_delete_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -336,7 +337,7 @@ class TestDatasetMetadataBuiltInFieldGet:
def test_get_built_in_fields_success(
self,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -382,7 +383,7 @@ class TestDatasetMetadataBuiltInFieldAction:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -414,7 +415,7 @@ class TestDatasetMetadataBuiltInFieldAction:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -441,7 +442,7 @@ class TestDatasetMetadataBuiltInFieldAction:
def test_action_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -485,7 +486,7 @@ class TestDocumentMetadataEditPost:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -513,7 +514,7 @@ class TestDocumentMetadataEditPost:
def test_update_documents_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):

View File

@ -5,6 +5,7 @@ Unit tests for Service API Index endpoint
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.service_api.index import IndexApi
@ -13,7 +14,7 @@ class TestIndexApi:
"""Test suite for IndexApi resource."""
@patch("controllers.service_api.index.dify_config", autospec=True)
def test_get_returns_api_info(self, mock_config, app):
def test_get_returns_api_info(self, mock_config, app: Flask):
"""Test that GET returns API metadata with correct structure."""
# Arrange
mock_config.project.version = "1.0.0-test"
@ -32,7 +33,7 @@ class TestIndexApi:
assert response["api_version"] == "v1"
assert response["server_version"] == "1.0.0-test"
def test_get_response_has_required_fields(self, app):
def test_get_response_has_required_fields(self, app: Flask):
"""Test that response contains all required fields."""
# Arrange
mock_config = MagicMock()

View File

@ -39,7 +39,7 @@ class TestValidateAndGetApiToken:
app.config["TESTING"] = True
return app
def test_missing_authorization_header(self, app):
def test_missing_authorization_header(self, app: Flask):
"""Test that Unauthorized is raised when Authorization header is missing."""
# Arrange
with app.test_request_context("/", method="GET"):
@ -50,7 +50,7 @@ class TestValidateAndGetApiToken:
validate_and_get_api_token("app")
assert "Authorization header must be provided" in str(exc_info.value)
def test_invalid_auth_scheme(self, app):
def test_invalid_auth_scheme(self, app: Flask):
"""Test that Unauthorized is raised when auth scheme is not Bearer."""
# Arrange
with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}):
@ -62,7 +62,7 @@ class TestValidateAndGetApiToken:
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask):
"""Test that valid token returns the ApiToken object."""
# Arrange
mock_api_token = Mock(spec=ApiToken)
@ -84,7 +84,7 @@ class TestValidateAndGetApiToken:
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask):
"""Test that invalid token raises Unauthorized."""
# Arrange
from werkzeug.exceptions import Unauthorized
@ -161,7 +161,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app no longer exists."""
# Arrange
mock_api_token = Mock()
@ -182,7 +182,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app status is abnormal."""
# Arrange
mock_api_token = Mock()
@ -205,7 +205,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app API is disabled."""
# Arrange
mock_api_token = Mock()
@ -240,7 +240,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app):
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that request is allowed when under resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -264,7 +264,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app):
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that Forbidden is raised when at resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -287,7 +287,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app):
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that request is allowed when billing is disabled."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -320,7 +320,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app):
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that add_segment is rejected in SANDBOX plan."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -342,7 +342,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app):
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that non-add_segment operations are allowed in SANDBOX."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -376,7 +376,7 @@ class TestCloudEditionBillingRateLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app):
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app: Flask):
"""Test that request is allowed when within rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -406,7 +406,7 @@ class TestCloudEditionBillingRateLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
@patch("controllers.service_api.wraps.db")
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app):
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app: Flask):
"""Test that Forbidden is raised when over rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -445,7 +445,7 @@ class TestValidateDatasetToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.current_app")
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app):
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask):
"""Test that valid dataset token allows access."""
# Arrange
# Use standard Mock for login_manager
@ -487,7 +487,7 @@ class TestValidateDatasetToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app):
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app: Flask):
"""Test that NotFound is raised when dataset doesn't exist."""
# Arrange
mock_api_token = Mock()

View File

@ -13,7 +13,7 @@ class TestExternalDataFetch:
app = Flask(__name__)
return app
def test_fetch_success(self, app):
def test_fetch_success(self, app: Flask):
with app.app_context():
fetcher = ExternalDataFetch()
@ -79,7 +79,7 @@ class TestExternalDataFetch:
assert result_inputs == inputs
assert result_inputs is not inputs # Should be a copy
def test_fetch_with_none_variable(self, app):
def test_fetch_with_none_variable(self, app: Flask):
with app.app_context():
fetcher = ExternalDataFetch()
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})
@ -95,7 +95,7 @@ class TestExternalDataFetch:
assert "var1" not in result_inputs
assert result_inputs == {"in": "val"}
def test_query_external_data_tool(self, app):
def test_query_external_data_tool(self, app: Flask):
fetcher = ExternalDataFetch()
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})