Merge remote-tracking branch 'upstream/feat/hitl-form-enhancement' into feat/hitl-form-enhancement

This commit is contained in:
QuantumGhost 2026-05-08 11:44:37 +08:00
commit ed98925f11
234 changed files with 6937 additions and 10451 deletions

View File

@ -109,6 +109,8 @@ jobs:
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check

View File

@ -75,14 +75,15 @@ console_ns.schema_model(
def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
match value:
case FileSegment():
return value.value.model_dump()
case ArrayFileSegment():
return [i.model_dump() for i in value.value]
case SegmentGroup():
return [_convert_values_to_json_serializable_object(i) for i in value.value]
case _:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable):

View File

@ -1,9 +0,0 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)

View File

@ -90,7 +90,7 @@ class TestOAuthLogin:
mock_redirect,
mock_get_providers,
resource,
app,
app: Flask,
mock_oauth_provider,
invite_token,
expected_token,
@ -165,7 +165,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -218,7 +218,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -262,7 +262,7 @@ class TestOAuthCallback:
mock_tenant_service,
mock_account_service,
resource,
app,
app: Flask,
oauth_setup,
account_status,
expected_redirect,
@ -301,7 +301,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -337,7 +337,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
"""Defensive test for CLOSED account status handling in OAuth callback.
@ -466,7 +466,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
allow_register,
@ -505,7 +505,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
@ -530,7 +530,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_tenant_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
):

View File

@ -47,7 +47,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
):
# Arrange
@ -105,7 +105,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -154,7 +154,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
"""
Test successful verification code validation.
@ -201,7 +201,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
@ -345,7 +345,7 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
app,
app: Flask,
mock_account,
):
"""

View File

@ -30,7 +30,7 @@ class TestPipelineTemplateListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateListApi()
method = unwrap(api.get)
@ -54,7 +54,7 @@ class TestPipelineTemplateDetailApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -75,7 +75,7 @@ class TestPipelineTemplateDetailApi:
assert status == 200
assert response == template
def test_get_returns_404_when_template_not_found(self, app):
def test_get_returns_404_when_template_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -94,7 +94,7 @@ class TestPipelineTemplateDetailApi:
assert status == 404
assert "error" in response
def test_get_returns_404_for_customized_type_not_found(self, app):
def test_get_returns_404_for_customized_type_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -119,7 +119,7 @@ class TestCustomizedPipelineTemplateApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
@ -141,7 +141,7 @@ class TestCustomizedPipelineTemplateApi:
update_mock.assert_called_once()
assert response == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
@ -156,7 +156,7 @@ class TestCustomizedPipelineTemplateApi:
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app, db_session_with_containers: Session):
def test_post_success(self, app: Flask, db_session_with_containers: Session):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -183,7 +183,7 @@ class TestCustomizedPipelineTemplateApi:
assert status == 200
assert response == {"data": "yaml-data"}
def test_post_template_not_found(self, app):
def test_post_template_not_found(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -197,7 +197,7 @@ class TestPublishCustomizedPipelineTemplateApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)

View File

@ -36,7 +36,7 @@ class TestRagPipelineImportApi:
"name": "Test",
}
def test_post_success_200(self, app):
def test_post_success_200(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -66,7 +66,7 @@ class TestRagPipelineImportApi:
assert status == 200
assert response == {"status": "success"}
def test_post_failed_400(self, app):
def test_post_failed_400(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -96,7 +96,7 @@ class TestRagPipelineImportApi:
assert status == 400
assert response == {"status": "failed"}
def test_post_pending_202(self, app):
def test_post_pending_202(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -132,7 +132,7 @@ class TestRagPipelineImportConfirmApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_confirm_success(self, app):
def test_confirm_success(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -160,7 +160,7 @@ class TestRagPipelineImportConfirmApi:
assert status == 200
assert response == {"ok": True}
def test_confirm_failed(self, app):
def test_confirm_failed(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -194,7 +194,7 @@ class TestRagPipelineImportCheckDependenciesApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
@ -223,7 +223,7 @@ class TestRagPipelineExportApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_with_include_secret(self, app):
def test_get_with_include_secret(self, app: Flask):
api = RagPipelineExportApi()
method = unwrap(api.get)

View File

@ -391,7 +391,7 @@ class TestPublishedPipelineApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi()

View File

@ -55,7 +55,7 @@ class TestDataSourceApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -79,7 +79,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
def test_get_no_bindings(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -95,7 +95,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -116,7 +116,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -137,7 +137,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -152,7 +152,7 @@ class TestDataSourceApi:
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -169,7 +169,7 @@ class TestDataSourceApi:
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -192,7 +192,7 @@ class TestDataSourceNotionListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_credential_not_found(self, app, patch_tenant):
def test_get_credential_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -206,7 +206,7 @@ class TestDataSourceNotionListApi:
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -247,7 +247,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -300,7 +300,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -327,7 +327,7 @@ class TestDataSourceNotionApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_preview_success(self, app, patch_tenant):
def test_get_preview_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
@ -348,7 +348,7 @@ class TestDataSourceNotionApi:
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
@ -385,7 +385,7 @@ class TestDataSourceNotionDatasetSyncApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -408,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi:
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -428,7 +428,7 @@ class TestDataSourceNotionDocumentSyncApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
@ -451,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi:
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
def test_get_document_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)

View File

@ -57,7 +57,7 @@ class TestConversationListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
def test_get_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -82,7 +82,7 @@ class TestConversationListApi:
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app, chat_app, user):
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -98,7 +98,7 @@ class TestConversationListApi:
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app, non_chat_app):
def test_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -112,7 +112,7 @@ class TestConversationApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
def test_delete_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -130,7 +130,7 @@ class TestConversationApi:
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app, chat_app, user):
def test_delete_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -146,7 +146,7 @@ class TestConversationApi:
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app, non_chat_app):
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -160,7 +160,7 @@ class TestConversationRenameApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
def test_rename_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -179,7 +179,7 @@ class TestConversationRenameApi:
assert result["id"] == "cid"
def test_rename_not_found(self, app, chat_app, user):
def test_rename_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -201,7 +201,7 @@ class TestConversationPinApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
def test_pin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
@ -223,7 +223,7 @@ class TestConversationUnPinApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):
def test_unpin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)

View File

@ -49,7 +49,7 @@ class TestTriggerProviderApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = TriggerProviderIconApi()
method = unwrap(api.get)
@ -63,7 +63,7 @@ class TestTriggerProviderApis:
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
def test_list_providers(self, app: Flask):
api = TriggerProviderListApi()
method = unwrap(api.get)
@ -77,7 +77,7 @@ class TestTriggerProviderApis:
):
assert method(api) == []
def test_provider_info(self, app):
def test_provider_info(self, app: Flask):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
@ -97,7 +97,7 @@ class TestTriggerSubscriptionListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -111,7 +111,7 @@ class TestTriggerSubscriptionListApi:
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
def test_list_invalid_provider(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -132,7 +132,7 @@ class TestTriggerSubscriptionBuilderApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create_builder(self, app):
def test_create_builder(self, app: Flask):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
@ -147,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis:
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
def test_get_builder(self, app: Flask):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
@ -160,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
def test_verify_builder(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -174,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
def test_verify_builder_error(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -189,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis:
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
def test_update_builder(self, app: Flask):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
@ -203,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
def test_logs(self, app: Flask):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
@ -220,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
def test_build(self, app: Flask):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
@ -240,7 +240,7 @@ class TestTriggerSubscriptionCrud:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_update_rename_only(self, app):
def test_update_rename_only(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -259,7 +259,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
def test_update_not_found(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -274,7 +274,7 @@ class TestTriggerSubscriptionCrud:
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
def test_update_rebuild(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -297,7 +297,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
def test_delete_subscription(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -320,7 +320,7 @@ class TestTriggerSubscriptionCrud:
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
def test_delete_subscription_value_error(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -346,7 +346,7 @@ class TestTriggerOAuthApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_authorize_success(self, app):
def test_oauth_authorize_success(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -373,7 +373,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
def test_oauth_authorize_no_client(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -388,7 +388,7 @@ class TestTriggerOAuthApis:
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
def test_oauth_callback_forbidden(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -396,7 +396,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
def test_oauth_callback_success(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -426,7 +426,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
def test_oauth_callback_no_oauth_client(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -450,7 +450,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
def test_oauth_callback_empty_credentials(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -484,7 +484,7 @@ class TestTriggerOAuthClientManageApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_client(self, app):
def test_get_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
@ -511,7 +511,7 @@ class TestTriggerOAuthClientManageApi:
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
def test_post_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -525,7 +525,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
def test_delete_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
@ -539,7 +539,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
def test_oauth_client_post_value_error(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -560,7 +560,7 @@ class TestTriggerSubscriptionVerifyApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_verify_success(self, app):
def test_verify_success(self, app: Flask):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)

View File

@ -291,7 +291,7 @@ class TestDatasetListApiGet:
mock_current_user,
mock_provider_mgr,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -326,7 +326,7 @@ class TestDatasetListApiPost:
mock_dataset_svc,
mock_current_user,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -352,7 +352,7 @@ class TestDatasetListApiPost:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -390,7 +390,7 @@ class TestDatasetApiGet:
mock_provider_mgr,
mock_marshal,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -440,7 +440,7 @@ class TestDatasetApiGet:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -468,7 +468,7 @@ class TestDatasetApiDelete:
mock_dataset_svc,
mock_current_user,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -490,7 +490,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -511,7 +511,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -543,7 +543,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -574,7 +574,7 @@ class TestDocumentStatusApiPatch:
def test_batch_update_status_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -603,7 +603,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -636,7 +636,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -669,7 +669,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -709,7 +709,7 @@ class TestDatasetTagsApiGet:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -731,7 +731,7 @@ class TestDatasetTagsApiGet:
def test_list_tags_from_db(
self,
mock_current_user,
app,
app: Flask,
db_session_with_containers: Session,
):
"""Integration test: creates real Tag rows and retrieves them
@ -774,7 +774,7 @@ class TestDatasetTagsApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -797,7 +797,7 @@ class TestDatasetTagsApiPost:
mock_tag_svc.save_tags.assert_called_once()
@patch("controllers.service_api.dataset.dataset.current_user")
def test_create_tag_forbidden(self, mock_current_user, app):
def test_create_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -826,7 +826,7 @@ class TestDatasetTagsApiPatch:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -852,7 +852,7 @@ class TestDatasetTagsApiPatch:
mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_update_tag_forbidden(self, mock_current_user, app):
def test_update_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -880,7 +880,7 @@ class TestDatasetTagsApiDelete:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -905,7 +905,7 @@ class TestDatasetTagsApiDelete:
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
@patch("libs.login.current_user")
def test_delete_tag_forbidden(self, mock_current_user, app):
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
user_obj = Mock(spec=Account)
@ -933,7 +933,7 @@ class TestDatasetTagsBindingStatusApi:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi
@ -963,7 +963,7 @@ class TestDatasetTagBindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
@ -988,7 +988,7 @@ class TestDatasetTagBindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_bind_tags_forbidden(self, mock_current_user, app):
def test_bind_tags_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
mock_current_user.__class__ = Account
@ -1014,7 +1014,7 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1044,7 +1044,7 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1069,7 +1069,7 @@ class TestDatasetTagUnbindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_unbind_tag_forbidden(self, mock_current_user, app):
def test_unbind_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
mock_current_user.__class__ = Account

View File

@ -240,7 +240,7 @@ class TestDecodeJwtToken:
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
@ -300,7 +300,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
@ -325,7 +325,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
@ -351,7 +351,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)

View File

@ -21,6 +21,8 @@ from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
TenantAndAccount = tuple[Tenant, Account]
@dataclass
class TestTask:
@ -74,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration:
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
def test_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
def secondary_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
@ -95,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker):
def test_tenant_isolation(
self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker
):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
@ -115,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
def test_key_isolation(self, test_tenant_and_account: TenantAndAccount):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
@ -293,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker):
def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
@ -436,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker):
def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test compatibility with legacy queues containing only string data.
@ -466,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker):
def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test complete migration scenario from legacy to new system.
@ -547,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker):
def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test error recovery when legacy queue contains malformed data.

View File

@ -0,0 +1,94 @@
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
def _tenant_id() -> str:
return str(uuid4())
def _get_permission(session: Session, tenant_id: str) -> TenantPluginPermission | None:
session.expire_all()
stmt = select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
return session.scalars(stmt).one_or_none()
def _count_permissions(session: Session, tenant_id: str) -> int:
stmt = select(func.count()).select_from(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
return session.scalar(stmt) or 0
class TestGetPermission:
"""Integration tests for PluginPermissionService.get_permission using testcontainers."""
def test_returns_permission_when_found(self, db_session_with_containers: Session):
tenant_id = _tenant_id()
permission = TenantPluginPermission(
tenant_id=tenant_id,
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
db_session_with_containers.add(permission)
db_session_with_containers.commit()
result = PluginPermissionService.get_permission(tenant_id)
assert result is not None
assert result.id == permission.id
assert result.tenant_id == tenant_id
assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
def test_returns_none_when_not_found(self, db_session_with_containers: Session):
result = PluginPermissionService.get_permission(_tenant_id())
assert result is None
class TestChangePermission:
"""Integration tests for PluginPermissionService.change_permission using testcontainers."""
def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session):
tenant_id = _tenant_id()
result = PluginPermissionService.change_permission(
tenant_id,
TenantPluginPermission.InstallPermission.EVERYONE,
TenantPluginPermission.DebugPermission.EVERYONE,
)
permission = _get_permission(db_session_with_containers, tenant_id)
assert result is True
assert permission is not None
assert permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
assert permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
def test_updates_existing_permission(self, db_session_with_containers: Session):
tenant_id = _tenant_id()
existing = TenantPluginPermission(
tenant_id=tenant_id,
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
db_session_with_containers.add(existing)
db_session_with_containers.commit()
result = PluginPermissionService.change_permission(
tenant_id,
TenantPluginPermission.InstallPermission.ADMINS,
TenantPluginPermission.DebugPermission.ADMINS,
)
permission = _get_permission(db_session_with_containers, tenant_id)
assert result is True
assert permission is not None
assert permission.id == existing.id
assert permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
assert _count_permissions(db_session_with_containers, tenant_id) == 1

View File

@ -7,6 +7,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import App
from models.model import EndUser
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@ -184,7 +185,7 @@ class TestAppGenerateService:
return app, account
def _create_test_workflow(self, db_session_with_containers: Session, app):
def _create_test_workflow(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test workflow for testing.

View File

@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration:
return app
def _create_conversation(self, db_session_with_containers: Session, app):
def _create_conversation(self, db_session_with_containers: Session, app: App):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import App, CreatorUserRole
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
@ -88,7 +89,7 @@ class TestSavedMessageService:
return app, account
def _create_test_end_user(self, db_session_with_containers: Session, app):
def _create_test_end_user(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test end user for testing.
@ -116,7 +117,7 @@ class TestSavedMessageService:
return end_user
def _create_test_message(self, db_session_with_containers: Session, app, user):
def _create_test_message(self, db_session_with_containers: Session, app: App, user):
"""
Helper method to create a test message for testing.
@ -199,13 +200,13 @@ class TestSavedMessageService:
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -272,13 +273,13 @@ class TestSavedMessageService:
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="end_user",
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="end_user",
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
@ -449,7 +450,7 @@ class TestSavedMessageService:
saved_message = SavedMessage(
app_id=app.id,
message_id=message.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -540,7 +541,9 @@ class TestSavedMessageService:
message = self._create_test_message(db_session_with_containers, app, account)
# Pre-create a saved message
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id)
saved = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id
)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
@ -571,7 +574,9 @@ class TestSavedMessageService:
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id)
saved = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id
)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
@ -596,10 +601,10 @@ class TestSavedMessageService:
# Both users save the same message
saved_account = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account1.id
)
saved_end_user = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id
)
db_session_with_containers.add_all([saved_account, saved_end_user])
db_session_with_containers.commit()

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models import Account, App
from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
@ -93,7 +93,7 @@ class TestWebConversationService:
return app, account
def _create_test_end_user(self, db_session_with_containers: Session, app):
def _create_test_end_user(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test end user for testing.

View File

@ -185,7 +185,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):
@ -263,7 +263,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):
@ -312,7 +312,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
language,
@ -358,7 +358,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
):
"""
@ -398,7 +398,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
):
"""
@ -438,7 +438,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):

View File

@ -195,7 +195,7 @@ class TestEmailCodeLoginSendEmailApi:
mock_get_user,
mock_is_ip_limit,
mock_db,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -267,7 +267,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -315,7 +315,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -431,7 +431,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""
@ -474,7 +474,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""
@ -515,7 +515,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""

View File

@ -412,7 +412,7 @@ class TestLoginApi:
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -448,7 +448,7 @@ class TestLoginApi:
mock_revoke_token,
mock_get_token_data,
mock_db,
app,
app: Flask,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")

View File

@ -74,7 +74,7 @@ class TestRefreshTokenApi:
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
def test_refresh_fails_without_token(self, mock_extract_token, app):
def test_refresh_fails_without_token(self, mock_extract_token, app: Flask):
"""
Test token refresh failure when no refresh token provided.
@ -98,7 +98,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh failure with invalid refresh token.
@ -123,7 +123,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh failure with expired refresh token.
@ -148,7 +148,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh with empty string token.

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import console_ns
@ -29,7 +30,7 @@ def unwrap(func):
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -61,7 +62,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
assert response.status_code == 200
def test_get_no_oauth_config(self, app):
def test_get_no_oauth_config(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -80,7 +81,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_without_credential_id_sets_cookie(self, app):
def test_get_without_credential_id_sets_cookie(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -115,7 +116,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app):
def test_callback_success_new_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -157,7 +158,7 @@ class TestDatasourceOAuthCallback:
assert response.status_code == 302
def test_callback_missing_context(self, app):
def test_callback_missing_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -165,7 +166,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_invalid_context(self, app):
def test_callback_invalid_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -180,7 +181,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_oauth_config_not_found(self, app):
def test_callback_oauth_config_not_found(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -202,7 +203,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(NotFound):
method(api, "notion")
def test_callback_reauthorize_existing_credential(self, app):
def test_callback_reauthorize_existing_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -245,7 +246,7 @@ class TestDatasourceOAuthCallback:
assert response.status_code == 302
assert "/oauth-callback" in response.location
def test_callback_context_id_from_cookie(self, app):
def test_callback_context_id_from_cookie(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -289,7 +290,7 @@ class TestDatasourceOAuthCallback:
class TestDatasourceAuth:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -312,7 +313,7 @@ class TestDatasourceAuth:
assert status == 200
def test_post_invalid_credentials(self, app):
def test_post_invalid_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -334,7 +335,7 @@ class TestDatasourceAuth:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
@ -355,7 +356,7 @@ class TestDatasourceAuth:
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app):
def test_post_missing_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -372,7 +373,7 @@ class TestDatasourceAuth:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_empty_list(self, app):
def test_get_empty_list(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
@ -395,7 +396,7 @@ class TestDatasourceAuth:
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
@ -418,7 +419,7 @@ class TestDatasourceAuthDeleteApi:
assert status == 200
def test_delete_missing_credential_id(self, app):
def test_delete_missing_credential_id(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
@ -437,7 +438,7 @@ class TestDatasourceAuthDeleteApi:
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -460,7 +461,7 @@ class TestDatasourceAuthUpdateApi:
assert status == 201
def test_update_with_credentials_none(self, app):
def test_update_with_credentials_none(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -484,7 +485,7 @@ class TestDatasourceAuthUpdateApi:
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app):
def test_update_name_only(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -507,7 +508,7 @@ class TestDatasourceAuthUpdateApi:
assert status == 201
def test_update_with_empty_credentials_dict(self, app):
def test_update_with_empty_credentials_dict(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -533,7 +534,7 @@ class TestDatasourceAuthUpdateApi:
class TestDatasourceAuthListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
@ -553,7 +554,7 @@ class TestDatasourceAuthListApi:
assert status == 200
def test_auth_list_empty(self, app):
def test_auth_list_empty(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
@ -574,7 +575,7 @@ class TestDatasourceAuthListApi:
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app):
def test_hardcode_list_empty(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
@ -597,7 +598,7 @@ class TestDatasourceAuthListApi:
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
@ -619,7 +620,7 @@ class TestDatasourceHardCodeAuthListApi:
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -642,7 +643,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
@ -662,7 +663,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_post_empty_payload(self, app):
def test_post_empty_payload(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -685,7 +686,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_post_disabled_flag(self, app):
def test_post_disabled_flag(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -714,7 +715,7 @@ class TestDatasourceAuthOauthCustomClient:
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app):
def test_set_default_success(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
@ -737,7 +738,7 @@ class TestDatasourceAuthDefaultApi:
assert status == 200
def test_default_missing_id(self, app):
def test_default_missing_id(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
@ -756,7 +757,7 @@ class TestDatasourceAuthDefaultApi:
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app):
def test_update_name_success(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
@ -779,7 +780,7 @@ class TestDatasourceUpdateProviderNameApi:
assert status == 200
def test_update_name_too_long(self, app):
def test_update_name_too_long(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
@ -799,7 +800,7 @@ class TestDatasourceUpdateProviderNameApi:
with pytest.raises(ValueError):
method(api, "notion")
def test_update_name_missing_credential_id(self, app):
def test_update_name_missing_credential_id(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
@ -25,7 +26,7 @@ class TestDataSourceContentPreviewApi:
"credential_id": "cred-1",
}
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -66,7 +67,7 @@ class TestDataSourceContentPreviewApi:
assert status == 200
assert response == preview_result
def test_post_forbidden_non_account_user(self, app):
def test_post_forbidden_non_account_user(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -85,7 +86,7 @@ class TestDataSourceContentPreviewApi:
with pytest.raises(Forbidden):
method(api, pipeline, "node-1")
def test_post_invalid_payload(self, app):
def test_post_invalid_payload(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -108,7 +109,7 @@ class TestDataSourceContentPreviewApi:
with pytest.raises(ValueError):
method(api, pipeline, "node-1")
def test_post_without_credential_id(self, app):
def test_post_without_credential_id(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)

View File

@ -2,6 +2,7 @@ import datetime
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import services
@ -58,7 +59,7 @@ class TestDatasetList:
user.is_dataset_editor = True
return user
def test_get_success_basic(self, app):
def test_get_success_basic(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -93,7 +94,7 @@ class TestDatasetList:
assert resp["total"] == 1
assert resp["data"][0]["embedding_available"] is True
def test_get_with_ids_filter(self, app):
def test_get_with_ids_filter(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -128,7 +129,7 @@ class TestDatasetList:
assert status == 200
assert resp["total"] == 2
def test_get_with_tag_ids(self, app):
def test_get_with_tag_ids(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -161,7 +162,7 @@ class TestDatasetList:
assert status == 200
def test_embedding_available_false(self, app):
def test_embedding_available_false(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -203,7 +204,7 @@ class TestDatasetList:
assert resp["data"][0]["embedding_available"] is False
def test_partial_members_permission(self, app):
def test_partial_members_permission(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -242,7 +243,7 @@ class TestDatasetList:
class TestDatasetListApiPost:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -290,7 +291,7 @@ class TestDatasetListApiPost:
assert status == 201
def test_post_forbidden(self, app):
def test_post_forbidden(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -310,7 +311,7 @@ class TestDatasetListApiPost:
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
def test_post_duplicate_name(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -335,7 +336,7 @@ class TestDatasetListApiPost:
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload_missing_name(self, app):
def test_post_invalid_payload_missing_name(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -343,7 +344,7 @@ class TestDatasetListApiPost:
with pytest.raises(ValueError):
method(api)
def test_post_invalid_indexing_technique(self, app):
def test_post_invalid_indexing_technique(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -356,7 +357,7 @@ class TestDatasetListApiPost:
with pytest.raises(ValueError, match="Invalid indexing technique"):
method(api)
def test_post_invalid_provider(self, app):
def test_post_invalid_provider(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -371,7 +372,7 @@ class TestDatasetListApiPost:
class TestDatasetApiGet:
def test_get_success_basic(self, app):
def test_get_success_basic(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -427,7 +428,7 @@ class TestDatasetApiGet:
assert status == 200
assert data["embedding_available"] is True
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -448,7 +449,7 @@ class TestDatasetApiGet:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -475,7 +476,7 @@ class TestDatasetApiGet:
with pytest.raises(Forbidden, match="no access"):
method(api, dataset_id)
def test_get_high_quality_embedding_unavailable(self, app):
def test_get_high_quality_embedding_unavailable(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -530,7 +531,7 @@ class TestDatasetApiGet:
assert data["embedding_available"] is False
def test_get_partial_members_permission(self, app):
def test_get_partial_members_permission(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -590,7 +591,7 @@ class TestDatasetApiGet:
class TestDatasetApiPatch:
def test_patch_success_basic(self, app):
def test_patch_success_basic(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -659,7 +660,7 @@ class TestDatasetApiPatch:
assert status == 200
assert result["partial_member_list"] == []
def test_patch_dataset_not_found(self, app):
def test_patch_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -674,7 +675,7 @@ class TestDatasetApiPatch:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, "missing")
def test_patch_permission_denied(self, app):
def test_patch_permission_denied(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -704,7 +705,7 @@ class TestDatasetApiPatch:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_patch_partial_members_update(self, app):
def test_patch_partial_members_update(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -773,7 +774,7 @@ class TestDatasetApiPatch:
assert result["partial_member_list"] == payload["partial_member_list"]
def test_patch_clear_partial_members(self, app):
def test_patch_clear_partial_members(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -843,7 +844,7 @@ class TestDatasetApiPatch:
class TestDatasetApiDelete:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -874,7 +875,7 @@ class TestDatasetApiDelete:
assert status == 204
assert result == {"result": "success"}
def test_delete_forbidden_no_permission(self, app):
def test_delete_forbidden_no_permission(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -893,7 +894,7 @@ class TestDatasetApiDelete:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_delete_dataset_not_found(self, app):
def test_delete_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -917,7 +918,7 @@ class TestDatasetApiDelete:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_delete_dataset_in_use(self, app):
def test_delete_dataset_in_use(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -943,7 +944,7 @@ class TestDatasetApiDelete:
class TestDatasetUseCheckApi:
def test_get_use_check_true(self, app):
def test_get_use_check_true(self, app: Flask):
api = DatasetUseCheckApi()
method = unwrap(api.get)
@ -962,7 +963,7 @@ class TestDatasetUseCheckApi:
assert status == 200
assert result == {"is_using": True}
def test_get_use_check_false(self, app):
def test_get_use_check_false(self, app: Flask):
api = DatasetUseCheckApi()
method = unwrap(api.get)
@ -983,7 +984,7 @@ class TestDatasetUseCheckApi:
class TestDatasetQueryApi:
def test_get_queries_success(self, app):
def test_get_queries_success(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1027,7 +1028,7 @@ class TestDatasetQueryApi:
assert response["has_more"] is False
assert len(response["data"]) == 2
def test_get_queries_dataset_not_found(self, app):
def test_get_queries_dataset_not_found(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1049,7 +1050,7 @@ class TestDatasetQueryApi:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_get_queries_permission_denied(self, app):
def test_get_queries_permission_denied(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1078,7 +1079,7 @@ class TestDatasetQueryApi:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_get_queries_pagination_has_more(self, app):
def test_get_queries_pagination_has_more(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1152,7 +1153,7 @@ class TestDatasetIndexingEstimateApi:
"dataset_id": None,
}
def test_post_success_upload_file(self, app):
def test_post_success_upload_file(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
@ -1193,7 +1194,7 @@ class TestDatasetIndexingEstimateApi:
assert status == 200
assert response == {"tokens": 100}
def test_post_file_not_found(self, app):
def test_post_file_not_found(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
@ -1223,7 +1224,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(NotFound):
method(api)
def test_post_llm_bad_request_error(self, app):
def test_post_llm_bad_request_error(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1258,7 +1259,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(ProviderNotInitializeError):
method(api)
def test_post_provider_token_not_init(self, app):
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1293,7 +1294,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(ProviderNotInitializeError):
method(api)
def test_post_generic_exception(self, app):
def test_post_generic_exception(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1330,7 +1331,7 @@ class TestDatasetIndexingEstimateApi:
class TestDatasetRelatedAppListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1368,7 +1369,7 @@ class TestDatasetRelatedAppListApi:
assert response["total"] == 2
assert response["data"] == [app1, app2]
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1386,7 +1387,7 @@ class TestDatasetRelatedAppListApi:
with pytest.raises(NotFound):
method(api, "dataset-1")
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1410,7 +1411,7 @@ class TestDatasetRelatedAppListApi:
with pytest.raises(Forbidden):
method(api, "dataset-1")
def test_get_filters_none_apps(self, app):
def test_get_filters_none_apps(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1449,7 +1450,7 @@ class TestDatasetRelatedAppListApi:
class TestDatasetIndexingStatusApi:
def test_get_success_with_documents(self, app):
def test_get_success_with_documents(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1490,7 +1491,7 @@ class TestDatasetIndexingStatusApi:
assert item["completed_segments"] == 3
assert item["total_segments"] == 3
def test_get_success_no_documents(self, app):
def test_get_success_no_documents(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1510,7 +1511,7 @@ class TestDatasetIndexingStatusApi:
assert status == 200
assert response == {"data": []}
def test_segment_counts_different_values(self, app):
def test_segment_counts_different_values(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1550,7 +1551,7 @@ class TestDatasetIndexingStatusApi:
class TestDatasetApiKeyApi:
def test_get_api_keys_success(self, app):
def test_get_api_keys_success(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.get)
@ -1587,7 +1588,7 @@ class TestDatasetApiKeyApi:
assert response["data"][1]["id"] == "key-2"
assert response["data"][1]["token"] == "ds-def"
def test_post_create_api_key_success(self, app):
def test_post_create_api_key_success(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.post)
@ -1632,7 +1633,7 @@ class TestDatasetApiKeyApi:
assert response["type"] == "dataset"
assert response["created_at"] is not None
def test_post_exceed_max_keys(self, app):
def test_post_exceed_max_keys(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.post)
@ -1658,7 +1659,7 @@ class TestDatasetApiKeyApi:
class TestDatasetApiDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasetApiDeleteApi()
method = unwrap(api.delete)
@ -1688,7 +1689,7 @@ class TestDatasetApiDeleteApi:
assert status == 204
assert response["result"] == "success"
def test_delete_key_not_found(self, app):
def test_delete_key_not_found(self, app: Flask):
api = DatasetApiDeleteApi()
method = unwrap(api.delete)
@ -1708,7 +1709,7 @@ class TestDatasetApiDeleteApi:
class TestDatasetEnableApiApi:
def test_enable_api(self, app):
def test_enable_api(self, app: Flask):
api = DatasetEnableApiApi()
method = unwrap(api.post)
@ -1724,7 +1725,7 @@ class TestDatasetEnableApiApi:
assert status == 200
assert response["result"] == "success"
def test_disable_api(self, app):
def test_disable_api(self, app: Flask):
api = DatasetEnableApiApi()
method = unwrap(api.post)
@ -1742,7 +1743,7 @@ class TestDatasetEnableApiApi:
class TestDatasetApiBaseUrlApi:
def test_get_api_base_url_from_config(self, app):
def test_get_api_base_url_from_config(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1757,7 +1758,7 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "https://example.com/v1"
def test_get_api_base_url_from_request(self, app):
def test_get_api_base_url_from_request(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1772,7 +1773,7 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "http://localhost:5000/v1"
def test_get_api_base_url_no_double_v1(self, app):
def test_get_api_base_url_no_double_v1(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1789,7 +1790,7 @@ class TestDatasetApiBaseUrlApi:
class TestDatasetRetrievalSettingApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRetrievalSettingApi()
method = unwrap(api.get)
@ -1810,7 +1811,7 @@ class TestDatasetRetrievalSettingApi:
class TestDatasetRetrievalSettingMockApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRetrievalSettingMockApi()
method = unwrap(api.get)
@ -1827,7 +1828,7 @@ class TestDatasetRetrievalSettingMockApi:
class TestDatasetErrorDocs:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetErrorDocs()
method = unwrap(api.get)
@ -1850,7 +1851,7 @@ class TestDatasetErrorDocs:
assert status == 200
assert response["total"] == 1
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetErrorDocs()
method = unwrap(api.get)
@ -1866,7 +1867,7 @@ class TestDatasetErrorDocs:
class TestDatasetPermissionUserListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetPermissionUserListApi()
method = unwrap(api.get)
@ -1897,7 +1898,7 @@ class TestDatasetPermissionUserListApi:
assert status == 200
assert response["data"] == users
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetPermissionUserListApi()
method = unwrap(api.get)
@ -1923,7 +1924,7 @@ class TestDatasetPermissionUserListApi:
class TestDatasetAutoDisableLogApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetAutoDisableLogApi()
method = unwrap(api.get)
@ -1946,7 +1947,7 @@ class TestDatasetAutoDisableLogApi:
assert status == 200
assert response == logs
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetAutoDisableLogApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -239,7 +240,7 @@ class TestDatasetDocumentListApi:
assert "documents" in response
def test_post_forbidden(self, app):
def test_post_forbidden(self, app: Flask):
api = DatasetDocumentListApi()
method = unwrap(api.post)
@ -395,7 +396,7 @@ class TestDocumentDownloadApi:
class TestDocumentProcessingApi:
def test_processing_forbidden_when_not_editor(self, app):
def test_processing_forbidden_when_not_editor(self, app: Flask):
api = DocumentProcessingApi()
method = unwrap(api.patch)
@ -1185,7 +1186,7 @@ class TestDocumentPermissionCases:
"preview": [],
}
def test_document_tenant_mismatch(self, app):
def test_document_tenant_mismatch(self, app: Flask):
api = DocumentApi()
method = unwrap(api.get)
@ -1253,7 +1254,7 @@ class TestDocumentPermissionCases:
assert status == 200
assert response["mode"] == "custom"
def test_process_rule_permission_denied(self, app):
def test_process_rule_permission_denied(self, app: Flask):
api = GetProcessRuleApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -82,7 +83,7 @@ def test_get_segment_with_summary(monkeypatch):
class TestDatasetDocumentSegmentListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -132,7 +133,7 @@ class TestDatasetDocumentSegmentListApi:
assert status == 200
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -150,7 +151,7 @@ class TestDatasetDocumentSegmentListApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -176,7 +177,7 @@ class TestDatasetDocumentSegmentListApi:
class TestDatasetDocumentSegmentApi:
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -221,7 +222,7 @@ class TestDatasetDocumentSegmentApi:
assert status == 200
assert response["result"] == "success"
def test_patch_document_indexing_in_progress(self, app):
def test_patch_document_indexing_in_progress(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -264,7 +265,7 @@ class TestDatasetDocumentSegmentApi:
with pytest.raises(InvalidActionError):
method(api, "ds-1", "doc-1", "disable")
def test_patch_llm_bad_request(self, app):
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -308,7 +309,7 @@ class TestDatasetDocumentSegmentApi:
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1", "enable")
def test_patch_provider_token_not_init(self, app):
def test_patch_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -354,7 +355,7 @@ class TestDatasetDocumentSegmentApi:
class TestDatasetDocumentSegmentAddApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -413,7 +414,7 @@ class TestDatasetDocumentSegmentAddApi:
assert status == 200
assert response["data"]["id"] == "seg-1"
def test_post_llm_bad_request(self, app):
def test_post_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -452,7 +453,7 @@ class TestDatasetDocumentSegmentAddApi:
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1")
def test_post_provider_token_not_init(self, app):
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -493,7 +494,7 @@ class TestDatasetDocumentSegmentAddApi:
class TestDatasetDocumentSegmentUpdateApi:
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
@ -551,7 +552,7 @@ class TestDatasetDocumentSegmentUpdateApi:
assert status == 200
assert "data" in response
def test_patch_llm_bad_request(self, app):
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
@ -596,7 +597,7 @@ class TestDatasetDocumentSegmentUpdateApi:
class TestDatasetDocumentSegmentBatchImportApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -638,7 +639,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
assert status == 200
assert response["job_status"] == "waiting"
def test_post_dataset_not_found(self, app):
def test_post_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -659,7 +660,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_document_not_found(self, app):
def test_post_document_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -684,7 +685,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_upload_file_not_found(self, app):
def test_post_upload_file_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -713,7 +714,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_invalid_file_type(self, app):
def test_post_invalid_file_type(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -745,7 +746,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(ValueError):
method(api, "ds-1", "doc-1")
def test_post_async_task_failure(self, app):
def test_post_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -783,7 +784,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
assert status == 500
assert "error" in response
def test_get_job_not_found_in_redis(self, app):
def test_get_job_not_found_in_redis(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)
@ -799,7 +800,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
class TestChildChunkAddApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
@ -852,7 +853,7 @@ class TestChildChunkAddApi:
assert status == 200
assert response["data"]["id"] == "cc-1"
def test_post_child_chunk_indexing_error(self, app):
def test_post_child_chunk_indexing_error(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
@ -897,7 +898,7 @@ class TestChildChunkAddApi:
class TestChildChunkUpdateApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
@ -941,7 +942,7 @@ class TestChildChunkUpdateApi:
assert status == 204
assert response["result"] == "success"
def test_delete_child_chunk_index_error(self, app):
def test_delete_child_chunk_index_error(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
@ -984,7 +985,7 @@ class TestChildChunkUpdateApi:
class TestSegmentListAdvancedCases:
def test_segment_list_with_keyword_filter(self, app):
def test_segment_list_with_keyword_filter(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1035,7 +1036,7 @@ class TestSegmentListAdvancedCases:
assert status == 200
assert response["total"] == 1
def test_segment_list_permission_denied(self, app):
def test_segment_list_permission_denied(self, app: Flask):
"""Test segment list with permission denied"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1058,7 +1059,7 @@ class TestSegmentListAdvancedCases:
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1")
def test_segment_list_dataset_not_found(self, app):
def test_segment_list_dataset_not_found(self, app: Flask):
"""Test segment list with dataset not found"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1079,7 +1080,7 @@ class TestSegmentListAdvancedCases:
class TestSegmentOperationCases:
def test_segment_add_with_provider_token_error(self, app):
def test_segment_add_with_provider_token_error(self, app: Flask):
"""Test segment add with provider token not initialized"""
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -1117,7 +1118,7 @@ class TestSegmentOperationCases:
with pytest.raises(ProviderTokenNotInitError):
method(api, "ds-1", "doc-1")
def test_batch_import_with_document_not_found(self, app):
def test_batch_import_with_document_not_found(self, app: Flask):
"""Test batch import with document not found"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1146,7 +1147,7 @@ class TestSegmentOperationCases:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_batch_import_with_invalid_file(self, app):
def test_batch_import_with_invalid_file(self, app: Flask):
"""Test batch import with invalid file type"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1181,7 +1182,7 @@ class TestSegmentOperationCases:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_batch_import_with_async_task_failure(self, app):
def test_batch_import_with_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1226,7 +1227,7 @@ class TestSegmentOperationCases:
assert status == 500
assert "error" in response
def test_batch_import_get_job_not_found(self, app):
def test_batch_import_get_job_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)

View File

@ -57,7 +57,7 @@ def mock_auth(monkeypatch, current_user):
class TestExternalApiTemplateListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
@ -78,7 +78,7 @@ class TestExternalApiTemplateListApi:
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
def test_post_forbidden(self, app, current_user):
def test_post_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
@ -93,7 +93,7 @@ class TestExternalApiTemplateListApi:
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
def test_post_duplicate_name(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
@ -114,7 +114,7 @@ class TestExternalApiTemplateListApi:
class TestExternalApiTemplateApi:
def test_get_not_found(self, app):
def test_get_not_found(self, app: Flask):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
@ -129,7 +129,7 @@ class TestExternalApiTemplateApi:
with pytest.raises(NotFound):
method(api, "api-id")
def test_delete_forbidden(self, app, current_user):
def test_delete_forbidden(self, app: Flask, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
@ -142,7 +142,7 @@ class TestExternalApiTemplateApi:
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app):
def test_get_scopes_usage_check_to_current_tenant(self, app: Flask):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
@ -162,7 +162,7 @@ class TestExternalApiUseCheckApi:
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
@ -206,7 +206,7 @@ class TestExternalDatasetCreateApi:
assert status == 201
def test_create_forbidden(self, app, current_user):
def test_create_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
@ -226,7 +226,7 @@ class TestExternalDatasetCreateApi:
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app):
def test_hit_testing_dataset_not_found(self, app: Flask):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
@ -241,7 +241,7 @@ class TestExternalKnowledgeHitTestingApi:
with pytest.raises(NotFound):
method(api, "dataset-id")
def test_hit_testing_success(self, app):
def test_hit_testing_success(self, app: Flask):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
@ -266,7 +266,7 @@ class TestExternalKnowledgeHitTestingApi:
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app):
def test_bedrock_retrieval(self, app: Flask):
api = BedrockRetrievalApi()
method = unwrap(api.post)

View File

@ -269,7 +269,7 @@ class TestDatasetMetadataApi:
class TestDatasetMetadataBuiltInFieldApi:
def test_get_built_in_fields(self, app):
def test_get_built_in_fields(self, app: Flask):
api = DatasetMetadataBuiltInFieldApi()
method = unwrap(api.get)

View File

@ -1,6 +1,8 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from flask import Flask
import controllers.console.explore.banner as banner_module
from models.enums import BannerStatus
@ -12,7 +14,7 @@ def unwrap(func):
class TestBannerApi:
def test_get_banners_with_requested_language(self, app):
def test_get_banners_with_requested_language(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)
@ -41,7 +43,7 @@ class TestBannerApi:
}
]
def test_get_banners_fallback_to_en_us(self, app):
def test_get_banners_fallback_to_en_us(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)
@ -76,7 +78,7 @@ class TestBannerApi:
}
]
def test_get_banners_default_language_en_us(self, app):
def test_get_banners_default_language_en_us(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import InternalServerError, NotFound
import controllers.console.explore.message as module
@ -54,7 +55,7 @@ def make_message():
class TestMessageListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -96,7 +97,7 @@ class TestMessageListApi:
with pytest.raises(NotChatAppError):
method(installed_app)
def test_conversation_not_exists(self, app):
def test_conversation_not_exists(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -118,7 +119,7 @@ class TestMessageListApi:
with pytest.raises(NotFound):
method(installed_app)
def test_first_message_not_exists(self, app):
def test_first_message_not_exists(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -142,7 +143,7 @@ class TestMessageListApi:
class TestMessageFeedbackApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
@ -161,7 +162,7 @@ class TestMessageFeedbackApi:
assert result["result"] == "success"
def test_message_not_exists(self, app):
def test_message_not_exists(self, app: Flask):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
@ -182,7 +183,7 @@ class TestMessageFeedbackApi:
class TestMessageMoreLikeThisApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -221,7 +222,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(NotCompletionAppError):
method(installed_app, "mid")
def test_more_like_this_disabled(self, app):
def test_more_like_this_disabled(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -243,7 +244,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(AppMoreLikeThisDisabledError):
method(installed_app, "mid")
def test_message_not_exists_more_like_this(self, app):
def test_message_not_exists_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -265,7 +266,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_more_like_this(self, app):
def test_provider_not_init_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -287,7 +288,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_more_like_this(self, app):
def test_quota_exceeded_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -309,7 +310,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_more_like_this(self, app):
def test_model_not_support_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -331,7 +332,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_more_like_this(self, app):
def test_invoke_error_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -353,7 +354,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_more_like_this(self, app):
def test_unexpected_error_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)

View File

@ -1,5 +1,7 @@
from unittest.mock import MagicMock, patch
from flask import Flask
import controllers.console.explore.recommended_app as module
from models.model import AppMode, IconType
@ -11,7 +13,7 @@ def unwrap(func):
class TestRecommendedAppListApi:
def test_get_with_language_param(self, app):
def test_get_with_language_param(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -31,7 +33,7 @@ class TestRecommendedAppListApi:
service_mock.assert_called_once_with("en-US")
assert result == result_data
def test_get_fallback_to_user_language(self, app):
def test_get_fallback_to_user_language(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -51,7 +53,7 @@ class TestRecommendedAppListApi:
service_mock.assert_called_once_with("fr-FR")
assert result == result_data
def test_get_fallback_to_default_language(self, app):
def test_get_fallback_to_default_language(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -73,7 +75,7 @@ class TestRecommendedAppListApi:
class TestRecommendedAppApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.RecommendedAppApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.saved_message as module
@ -42,7 +43,7 @@ def payload_patch():
class TestSavedMessageListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.SavedMessageListApi()
method = unwrap(api.get)

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import controllers.console.explore.trial as module
@ -88,7 +89,7 @@ def valid_parameters():
class TestTrialAppWorkflowRunApi:
def test_not_workflow_app(self, app):
def test_not_workflow_app(self, app: Flask):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -224,7 +225,7 @@ class TestTrialAppWorkflowRunApi:
class TestTrialChatApi:
def test_not_chat_app(self, app):
def test_not_chat_app(self, app: Flask):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -408,7 +409,7 @@ class TestTrialChatApi:
class TestTrialCompletionApi:
def test_not_completion_app(self, app):
def test_not_completion_app(self, app: Flask):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -560,7 +561,7 @@ class TestTrialCompletionApi:
class TestTrialMessageSuggestedQuestionApi:
def test_not_chat_app(self, app):
def test_not_chat_app(self, app: Flask):
api = module.TrialMessageSuggestedQuestionApi()
method = unwrap(api.get)
@ -952,7 +953,7 @@ class TestTrialAppWorkflowTaskStopApi:
class TestTrialSitApi:
def test_no_site(self, app):
def test_no_site(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
app_model = MagicMock()
@ -963,7 +964,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_archived_tenant(self, app):
def test_archived_tenant(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
@ -978,7 +979,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_success(self, app):
def test_success(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)

View File

@ -73,7 +73,7 @@ def payload_patch():
class TestTagListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = TagListApi()
method = unwrap(api.get)
@ -124,7 +124,7 @@ class TestTagListApi:
assert result["name"] == "test-tag"
assert result["binding_count"] == "0"
def test_post_forbidden(self, app, readonly_user, payload_patch):
def test_post_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagListApi()
method = unwrap(api.post)
@ -170,7 +170,7 @@ class TestTagUpdateDeleteApi:
assert status == 200
assert result["binding_count"] == "3"
def test_patch_forbidden(self, app, readonly_user, payload_patch):
def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
@ -231,7 +231,7 @@ class TestTagBindingCollectionApi:
assert status == 200
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
def test_create_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagBindingCollectionApi()
method = unwrap(api.post)
@ -275,7 +275,7 @@ class TestTagBindingRemoveApi:
assert status == 200
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
def test_remove_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagBindingRemoveApi()
method = unwrap(api.post)

View File

@ -82,7 +82,7 @@ def mock_file_service(mock_db):
class TestFileApiGet:
def test_get_upload_config(self, app):
def test_get_upload_config(self, app: Flask):
api = FileApi()
get_method = unwrap(api.get)
@ -290,7 +290,7 @@ class TestFilePreviewApi:
class TestFileSupportTypeApi:
def test_get_supported_types(self, app):
def test_get_supported_types(self, app: Flask):
api = FileSupportTypeApi()
get_method = unwrap(api.get)

View File

@ -58,7 +58,7 @@ class TestChangeEmailSend:
mock_get_change_data,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
@ -107,7 +107,7 @@ class TestChangeEmailSend:
mock_get_change_data,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
@ -155,7 +155,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
@ -214,7 +214,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
@ -267,7 +267,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
@ -316,7 +316,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
@ -366,7 +366,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
@ -418,7 +418,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
@ -471,7 +471,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
@ -547,7 +547,7 @@ class TestAccountServiceSendChangeEmailEmail:
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask):
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@ -563,7 +563,7 @@ class TestCheckEmailUnique:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask):
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@ -43,7 +43,7 @@ class TestMemberInviteEmailApi:
mock_current_account,
mock_invite_member,
mock_get_features,
app,
app: Flask,
):
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -41,7 +42,7 @@ def unwrap(func):
class TestAccountInitApi:
def test_init_success(self, app):
def test_init_success(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
@ -64,7 +65,7 @@ class TestAccountInitApi:
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
def test_init_already_initialized(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
@ -79,7 +80,7 @@ class TestAccountInitApi:
class TestAccountProfileApi:
def test_get_profile_success(self, app):
def test_get_profile_success(self, app: Flask):
api = AccountProfileApi()
method = unwrap(api.get)
@ -140,7 +141,7 @@ class TestAccountUpdateApis:
class TestAccountAvatarApiGet:
"""GET /account/avatar must not sign arbitrary upload_file IDs (IDOR)."""
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app):
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -172,7 +173,7 @@ class TestAccountAvatarApiGet:
assert result == {"avatar_url": "https://signed/example"}
sign_mock.assert_called_once_with(upload_file_id=file_id)
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app):
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -204,7 +205,7 @@ class TestAccountAvatarApiGet:
sign_mock.assert_not_called()
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app):
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -236,7 +237,7 @@ class TestAccountAvatarApiGet:
sign_mock.assert_not_called()
def test_get_avatar_https_pass_through_without_signing(self, app):
def test_get_avatar_https_pass_through_without_signing(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -263,7 +264,7 @@ class TestAccountAvatarApiGet:
class TestAccountPasswordApi:
def test_password_success(self, app):
def test_password_success(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
@ -292,7 +293,7 @@ class TestAccountPasswordApi:
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
def test_password_wrong_current(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
@ -317,7 +318,7 @@ class TestAccountPasswordApi:
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
def test_get_integrates(self, app: Flask):
api = AccountIntegrateApi()
method = unwrap(api.get)
@ -336,7 +337,7 @@ class TestAccountIntegrateApi:
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
def test_delete_verify_success(self, app: Flask):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
@ -358,7 +359,7 @@ class TestAccountDeleteApi:
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
def test_delete_invalid_code(self, app: Flask):
api = AccountDeleteApi()
method = unwrap(api.post)
@ -379,7 +380,7 @@ class TestAccountDeleteApi:
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
def test_check_email_code_invalid(self, app: Flask):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
@ -405,7 +406,7 @@ class TestChangeEmailApis:
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
def test_reset_email_already_used(self, app: Flask):
api = ChangeEmailResetApi()
method = unwrap(api.post)
@ -427,7 +428,7 @@ class TestChangeEmailApis:
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
def test_email_unique_success(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)
@ -448,7 +449,7 @@ class TestCheckEmailUniqueApi:
assert result["result"] == "success"
def test_email_in_freeze(self, app):
def test_email_in_freeze(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
@ -16,7 +17,7 @@ def unwrap(func):
class TestAgentProviderListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -39,7 +40,7 @@ class TestAgentProviderListApi:
assert result == providers
def test_get_empty_list(self, app):
def test_get_empty_list(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -61,7 +62,7 @@ class TestAgentProviderListApi:
assert result == []
def test_get_account_not_found(self, app):
def test_get_account_not_found(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -77,7 +78,7 @@ class TestAgentProviderListApi:
class TestAgentProviderApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
@ -101,7 +102,7 @@ class TestAgentProviderApi:
assert result == provider_data
def test_get_provider_not_found(self, app):
def test_get_provider_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
@ -124,7 +125,7 @@ class TestAgentProviderApi:
assert result is None
def test_get_account_not_found(self, app):
def test_get_account_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.workspace.endpoint import (
@ -39,7 +40,7 @@ def patch_current_account(user_and_tenant):
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCollectionApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -57,7 +58,7 @@ class TestEndpointCollectionApi:
assert result["success"] is True
def test_create_permission_denied(self, app):
def test_create_permission_denied(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -77,7 +78,7 @@ class TestEndpointCollectionApi:
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
def test_create_validation_error(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -96,7 +97,7 @@ class TestEndpointCollectionApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointCreateApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = DeprecatedEndpointCreateApi()
method = unwrap(api.post)
@ -117,7 +118,7 @@ class TestDeprecatedEndpointCreateApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
@ -130,7 +131,7 @@ class TestEndpointListApi:
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
def test_list_invalid_query(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
@ -143,7 +144,7 @@ class TestEndpointListApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
def test_list_for_plugin_success(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
@ -158,7 +159,7 @@ class TestEndpointListForSinglePluginApi:
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
def test_list_for_plugin_missing_param(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
@ -171,7 +172,7 @@ class TestEndpointListForSinglePluginApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointItemApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
@ -187,7 +188,7 @@ class TestEndpointItemApi:
assert result["success"] is True
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
def test_delete_service_failure(self, app):
def test_delete_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
@ -199,7 +200,7 @@ class TestEndpointItemApi:
assert result["success"] is False
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -226,7 +227,7 @@ class TestEndpointItemApi:
settings={"x": 1},
)
def test_update_validation_error(self, app):
def test_update_validation_error(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -238,7 +239,7 @@ class TestEndpointItemApi:
with pytest.raises(ValueError):
method(api, "e1")
def test_update_service_failure(self, app):
def test_update_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -258,7 +259,7 @@ class TestEndpointItemApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -272,7 +273,7 @@ class TestDeprecatedEndpointDeleteApi:
assert result["success"] is True
def test_delete_invalid_payload(self, app):
def test_delete_invalid_payload(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -282,7 +283,7 @@ class TestDeprecatedEndpointDeleteApi:
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
def test_delete_service_failure(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -299,7 +300,7 @@ class TestDeprecatedEndpointDeleteApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -317,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi:
assert result["success"] is True
def test_update_validation_error(self, app):
def test_update_validation_error(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -329,7 +330,7 @@ class TestDeprecatedEndpointUpdateApi:
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
def test_update_service_failure(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -380,7 +381,7 @@ class TestEndpointRouteMetadata:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
def test_enable_success(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -394,7 +395,7 @@ class TestEndpointEnableApi:
assert result["success"] is True
def test_enable_invalid_payload(self, app):
def test_enable_invalid_payload(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -404,7 +405,7 @@ class TestEndpointEnableApi:
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
def test_enable_service_failure(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -421,7 +422,7 @@ class TestEndpointEnableApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
def test_disable_success(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)
@ -435,7 +436,7 @@ class TestEndpointDisableApi:
assert result["success"] is True
def test_disable_invalid_payload(self, app):
def test_disable_invalid_payload(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
import services
@ -34,7 +35,7 @@ def unwrap(func):
class TestMemberListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = MemberListApi()
method = unwrap(api.get)
@ -59,7 +60,7 @@ class TestMemberListApi:
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
def test_get_no_tenant(self, app: Flask):
api = MemberListApi()
method = unwrap(api.get)
@ -74,7 +75,7 @@ class TestMemberListApi:
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
def test_invite_success(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -101,7 +102,7 @@ class TestMemberInviteEmailApi:
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
def test_invite_limit_exceeded(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -123,7 +124,7 @@ class TestMemberInviteEmailApi:
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
def test_invite_already_member(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -151,7 +152,7 @@ class TestMemberInviteEmailApi:
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
def test_invite_invalid_role(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -166,7 +167,7 @@ class TestMemberInviteEmailApi:
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
def test_invite_generic_exception(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -196,7 +197,7 @@ class TestMemberInviteEmailApi:
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
def test_cancel_success(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -216,7 +217,7 @@ class TestMemberCancelInviteApi:
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
def test_cancel_not_found(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -233,7 +234,7 @@ class TestMemberCancelInviteApi:
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
def test_cancel_cannot_operate_self(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -255,7 +256,7 @@ class TestMemberCancelInviteApi:
assert status == 400
def test_cancel_no_permission(self, app):
def test_cancel_no_permission(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -277,7 +278,7 @@ class TestMemberCancelInviteApi:
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
def test_cancel_member_not_in_tenant(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -301,7 +302,7 @@ class TestMemberCancelInviteApi:
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -324,7 +325,7 @@ class TestMemberUpdateRoleApi:
assert result["result"] == "success"
def test_update_invalid_role(self, app):
def test_update_invalid_role(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -335,7 +336,7 @@ class TestMemberUpdateRoleApi:
assert status == 400
def test_update_member_not_found(self, app):
def test_update_member_not_found(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -354,7 +355,7 @@ class TestMemberUpdateRoleApi:
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
@ -381,7 +382,7 @@ class TestDatasetOperatorMemberListApi:
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
def test_get_no_tenant(self, app: Flask):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
@ -396,7 +397,7 @@ class TestDatasetOperatorMemberListApi:
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
def test_send_success(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -419,7 +420,7 @@ class TestSendOwnerTransferEmailApi:
assert result["result"] == "success"
def test_send_ip_limit(self, app):
def test_send_ip_limit(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -433,7 +434,7 @@ class TestSendOwnerTransferEmailApi:
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
def test_send_not_owner(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -452,7 +453,7 @@ class TestSendOwnerTransferEmailApi:
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
def test_check_invalid_code(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -477,7 +478,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
def test_rate_limited(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -498,7 +499,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
def test_invalid_token(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -520,7 +521,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
def test_invalid_email(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -547,7 +548,7 @@ class TestOwnerTransferCheckApi:
class TestOwnerTransferApi:
def test_transfer_self(self, app):
def test_transfer_self(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
@ -564,7 +565,7 @@ class TestOwnerTransferApi:
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
def test_invalid_token(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
@ -582,7 +583,7 @@ class TestOwnerTransferApi:
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
def test_member_not_in_tenant(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic_core import ValidationError
from werkzeug.exceptions import Forbidden
@ -26,7 +27,7 @@ def unwrap(func):
class TestModelProviderListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ModelProviderListApi()
method = unwrap(api.get)
@ -47,7 +48,7 @@ class TestModelProviderListApi:
class TestModelProviderCredentialApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
@ -66,7 +67,7 @@ class TestModelProviderCredentialApi:
assert "credentials" in result
def test_get_invalid_uuid(self, app):
def test_get_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
@ -80,7 +81,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_post_create_success(self, app):
def test_post_create_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
@ -102,7 +103,7 @@ class TestModelProviderCredentialApi:
assert result["result"] == "success"
assert status == 201
def test_post_create_validation_error(self, app):
def test_post_create_validation_error(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
@ -122,7 +123,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValueError):
method(api, provider="openai")
def test_put_update_success(self, app):
def test_put_update_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
@ -143,7 +144,7 @@ class TestModelProviderCredentialApi:
assert result["result"] == "success"
def test_put_invalid_uuid(self, app):
def test_put_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
@ -159,7 +160,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.delete)
@ -183,7 +184,7 @@ class TestModelProviderCredentialApi:
class TestModelProviderCredentialSwitchApi:
def test_switch_success(self, app):
def test_switch_success(self, app: Flask):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
@ -204,7 +205,7 @@ class TestModelProviderCredentialSwitchApi:
assert result["result"] == "success"
def test_switch_invalid_uuid(self, app):
def test_switch_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
@ -222,7 +223,7 @@ class TestModelProviderCredentialSwitchApi:
class TestModelProviderValidateApi:
def test_validate_success(self, app):
def test_validate_success(self, app: Flask):
api = ModelProviderValidateApi()
method = unwrap(api.post)
@ -243,7 +244,7 @@ class TestModelProviderValidateApi:
assert result["result"] == "success"
def test_validate_failure(self, app):
def test_validate_failure(self, app: Flask):
api = ModelProviderValidateApi()
method = unwrap(api.post)
@ -266,7 +267,7 @@ class TestModelProviderValidateApi:
class TestModelProviderIconApi:
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = ModelProviderIconApi()
with (
@ -280,7 +281,7 @@ class TestModelProviderIconApi:
assert response.mimetype == "image/png"
def test_icon_not_found(self, app):
def test_icon_not_found(self, app: Flask):
api = ModelProviderIconApi()
with (
@ -295,7 +296,7 @@ class TestModelProviderIconApi:
class TestPreferredProviderTypeUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
@ -316,7 +317,7 @@ class TestPreferredProviderTypeUpdateApi:
assert result["result"] == "success"
def test_invalid_enum(self, app):
def test_invalid_enum(self, app: Flask):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
@ -334,7 +335,7 @@ class TestPreferredProviderTypeUpdateApi:
class TestModelProviderPaymentCheckoutUrlApi:
def test_checkout_success(self, app):
def test_checkout_success(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
@ -359,7 +360,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
assert "url" in result
def test_invalid_provider(self, app):
def test_invalid_provider(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
@ -367,7 +368,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
with pytest.raises(ValueError):
method(api, provider="openai")
def test_permission_denied(self, app):
def test_permission_denied(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)

View File

@ -72,7 +72,7 @@ class TestDefaultModelApi:
assert result["result"] == "success"
def test_get_returns_empty_when_no_default(self, app):
def test_get_returns_empty_when_no_default(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.get)
@ -154,7 +154,7 @@ class TestModelProviderModelApi:
assert status == 204
def test_get_models_returns_empty(self, app):
def test_get_models_returns_empty(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.get)
@ -224,7 +224,7 @@ class TestModelProviderModelCredentialApi:
assert status == 201
def test_get_empty_credentials(self, app):
def test_get_empty_credentials(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
@ -242,7 +242,7 @@ class TestModelProviderModelCredentialApi:
assert result["credentials"] == {}
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.delete)
@ -416,7 +416,7 @@ class TestParameterAndAvailableModels:
assert "data" in result
def test_empty_rules(self, app):
def test_empty_rules(self, app: Flask):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
@ -431,7 +431,7 @@ class TestParameterAndAvailableModels:
assert result["data"] == []
def test_no_models(self, app):
def test_no_models(self, app: Flask):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ import io
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
@ -61,7 +62,7 @@ def tenant():
class TestPluginListLatestVersionsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginListLatestVersionsApi()
method = unwrap(api.post)
@ -77,7 +78,7 @@ class TestPluginListLatestVersionsApi:
assert "versions" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginListLatestVersionsApi()
method = unwrap(api.post)
@ -95,7 +96,7 @@ class TestPluginListLatestVersionsApi:
class TestPluginDebuggingKeyApi:
def test_debugging_key_success(self, app):
def test_debugging_key_success(self, app: Flask):
api = PluginDebuggingKeyApi()
method = unwrap(api.get)
@ -108,7 +109,7 @@ class TestPluginDebuggingKeyApi:
assert result["key"] == "k"
def test_debugging_key_error(self, app):
def test_debugging_key_error(self, app: Flask):
api = PluginDebuggingKeyApi()
method = unwrap(api.get)
@ -125,7 +126,7 @@ class TestPluginDebuggingKeyApi:
class TestPluginListApi:
def test_plugin_list(self, app):
def test_plugin_list(self, app: Flask):
api = PluginListApi()
method = unwrap(api.get)
@ -142,7 +143,7 @@ class TestPluginListApi:
class TestPluginIconApi:
def test_plugin_icon(self, app):
def test_plugin_icon(self, app: Flask):
api = PluginIconApi()
method = unwrap(api.get)
@ -156,7 +157,7 @@ class TestPluginIconApi:
class TestPluginAssetApi:
def test_plugin_asset(self, app):
def test_plugin_asset(self, app: Flask):
api = PluginAssetApi()
method = unwrap(api.get)
@ -171,7 +172,7 @@ class TestPluginAssetApi:
class TestPluginUploadFromPkgApi:
def test_upload_pkg_success(self, app):
def test_upload_pkg_success(self, app: Flask):
api = PluginUploadFromPkgApi()
method = unwrap(api.post)
@ -188,7 +189,7 @@ class TestPluginUploadFromPkgApi:
assert result["ok"] is True
def test_upload_pkg_too_large(self, app):
def test_upload_pkg_too_large(self, app: Flask):
api = PluginUploadFromPkgApi()
method = unwrap(api.post)
@ -210,7 +211,7 @@ class TestPluginUploadFromPkgApi:
class TestPluginInstallFromPkgApi:
def test_install_from_pkg(self, app):
def test_install_from_pkg(self, app: Flask):
api = PluginInstallFromPkgApi()
method = unwrap(api.post)
@ -229,7 +230,7 @@ class TestPluginInstallFromPkgApi:
class TestPluginUninstallApi:
def test_uninstall(self, app):
def test_uninstall(self, app: Flask):
api = PluginUninstallApi()
method = unwrap(api.post)
@ -246,7 +247,7 @@ class TestPluginUninstallApi:
class TestPluginChangePermissionApi:
def test_change_permission_forbidden(self, app):
def test_change_permission_forbidden(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
@ -264,7 +265,7 @@ class TestPluginChangePermissionApi:
with pytest.raises(Forbidden):
method(api)
def test_change_permission_success(self, app):
def test_change_permission_success(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
@ -286,7 +287,7 @@ class TestPluginChangePermissionApi:
class TestPluginFetchPermissionApi:
def test_fetch_permission_default(self, app):
def test_fetch_permission_default(self, app: Flask):
api = PluginFetchPermissionApi()
method = unwrap(api.get)
@ -319,7 +320,7 @@ class TestPluginFetchDynamicSelectOptionsApi:
class TestPluginReadmeApi:
def test_fetch_readme(self, app):
def test_fetch_readme(self, app: Flask):
api = PluginReadmeApi()
method = unwrap(api.get)
@ -334,7 +335,7 @@ class TestPluginReadmeApi:
class TestPluginListInstallationsFromIdsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -352,7 +353,7 @@ class TestPluginListInstallationsFromIdsApi:
assert "plugins" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -371,7 +372,7 @@ class TestPluginListInstallationsFromIdsApi:
class TestPluginUploadFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -388,7 +389,7 @@ class TestPluginUploadFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -407,7 +408,7 @@ class TestPluginUploadFromGithubApi:
class TestPluginUploadFromBundleApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUploadFromBundleApi()
method = unwrap(api.post)
@ -430,7 +431,7 @@ class TestPluginUploadFromBundleApi:
assert result["ok"] is True
def test_too_large(self, app):
def test_too_large(self, app: Flask):
api = PluginUploadFromBundleApi()
method = unwrap(api.post)
@ -458,7 +459,7 @@ class TestPluginUploadFromBundleApi:
class TestPluginInstallFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginInstallFromGithubApi()
method = unwrap(api.post)
@ -478,7 +479,7 @@ class TestPluginInstallFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginInstallFromGithubApi()
method = unwrap(api.post)
@ -502,7 +503,7 @@ class TestPluginInstallFromGithubApi:
class TestPluginInstallFromMarketplaceApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginInstallFromMarketplaceApi()
method = unwrap(api.post)
@ -520,7 +521,7 @@ class TestPluginInstallFromMarketplaceApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginInstallFromMarketplaceApi()
method = unwrap(api.post)
@ -539,7 +540,7 @@ class TestPluginInstallFromMarketplaceApi:
class TestPluginFetchMarketplacePkgApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchMarketplacePkgApi()
method = unwrap(api.get)
@ -552,7 +553,7 @@ class TestPluginFetchMarketplacePkgApi:
assert "manifest" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchMarketplacePkgApi()
method = unwrap(api.get)
@ -569,7 +570,7 @@ class TestPluginFetchMarketplacePkgApi:
class TestPluginFetchManifestApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchManifestApi()
method = unwrap(api.get)
@ -585,7 +586,7 @@ class TestPluginFetchManifestApi:
assert "manifest" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchManifestApi()
method = unwrap(api.get)
@ -602,7 +603,7 @@ class TestPluginFetchManifestApi:
class TestPluginFetchInstallTasksApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchInstallTasksApi()
method = unwrap(api.get)
@ -615,7 +616,7 @@ class TestPluginFetchInstallTasksApi:
assert "tasks" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchInstallTasksApi()
method = unwrap(api.get)
@ -632,7 +633,7 @@ class TestPluginFetchInstallTasksApi:
class TestPluginFetchInstallTaskApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchInstallTaskApi()
method = unwrap(api.get)
@ -645,7 +646,7 @@ class TestPluginFetchInstallTaskApi:
assert "task" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchInstallTaskApi()
method = unwrap(api.get)
@ -662,7 +663,7 @@ class TestPluginFetchInstallTaskApi:
class TestPluginDeleteInstallTaskApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteInstallTaskApi()
method = unwrap(api.post)
@ -675,7 +676,7 @@ class TestPluginDeleteInstallTaskApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteInstallTaskApi()
method = unwrap(api.post)
@ -692,7 +693,7 @@ class TestPluginDeleteInstallTaskApi:
class TestPluginDeleteAllInstallTaskItemsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteAllInstallTaskItemsApi()
method = unwrap(api.post)
@ -707,7 +708,7 @@ class TestPluginDeleteAllInstallTaskItemsApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteAllInstallTaskItemsApi()
method = unwrap(api.post)
@ -724,7 +725,7 @@ class TestPluginDeleteAllInstallTaskItemsApi:
class TestPluginDeleteInstallTaskItemApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteInstallTaskItemApi()
method = unwrap(api.post)
@ -737,7 +738,7 @@ class TestPluginDeleteInstallTaskItemApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteInstallTaskItemApi()
method = unwrap(api.post)
@ -754,7 +755,7 @@ class TestPluginDeleteInstallTaskItemApi:
class TestPluginUpgradeFromMarketplaceApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -775,7 +776,7 @@ class TestPluginUpgradeFromMarketplaceApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -797,7 +798,7 @@ class TestPluginUpgradeFromMarketplaceApi:
class TestPluginUpgradeFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUpgradeFromGithubApi()
method = unwrap(api.post)
@ -821,7 +822,7 @@ class TestPluginUpgradeFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUpgradeFromGithubApi()
method = unwrap(api.post)
@ -846,7 +847,7 @@ class TestPluginUpgradeFromGithubApi:
class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
@ -873,7 +874,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
assert result["options"] == [1]
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
@ -901,7 +902,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
class TestPluginChangePreferencesApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginChangePreferencesApi()
method = unwrap(api.post)
@ -931,7 +932,7 @@ class TestPluginChangePreferencesApi:
assert result["success"] is True
def test_permission_fail(self, app):
def test_permission_fail(self, app: Flask):
api = PluginChangePreferencesApi()
method = unwrap(api.post)
@ -962,7 +963,7 @@ class TestPluginChangePreferencesApi:
class TestPluginFetchPreferencesApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchPreferencesApi()
method = unwrap(api.get)
@ -996,7 +997,7 @@ class TestPluginFetchPreferencesApi:
class TestPluginAutoUpgradeExcludePluginApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginAutoUpgradeExcludePluginApi()
method = unwrap(api.post)
@ -1011,7 +1012,7 @@ class TestPluginAutoUpgradeExcludePluginApi:
assert result["success"] is True
def test_fail(self, app):
def test_fail(self, app: Flask):
api = PluginAutoUpgradeExcludePluginApi()
method = unwrap(api.post)

View File

@ -2,6 +2,7 @@ from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Unauthorized
@ -37,7 +38,7 @@ def unwrap(func):
class TestTenantListApi:
def test_get_success_saas_path(self, app):
def test_get_success_saas_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -85,7 +86,7 @@ class TestTenantListApi:
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app: Flask):
"""Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.
billing.enabled is mocked False to prove the endpoint does not gate on it for this path
@ -140,7 +141,7 @@ class TestTenantListApi:
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_called_once_with("t2")
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask):
"""Test fallback to FeatureService when bulk billing returns empty result.
BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
@ -197,7 +198,7 @@ class TestTenantListApi:
assert get_features_mock.call_count == 2
logger_warning_mock.assert_called_once()
def test_get_billing_disabled_community_path(self, app):
def test_get_billing_disabled_community_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -236,7 +237,7 @@ class TestTenantListApi:
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
def test_get_enterprise_only_skips_feature_service(self, app):
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -276,7 +277,7 @@ class TestTenantListApi:
assert result["workspaces"][1]["current"] is True
get_features_mock.assert_not_called()
def test_get_enterprise_only_with_empty_tenants(self, app):
def test_get_enterprise_only_with_empty_tenants(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -302,7 +303,7 @@ class TestTenantListApi:
class TestWorkspaceListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
@ -324,7 +325,7 @@ class TestWorkspaceListApi:
assert result["total"] == 1
assert result["has_more"] is False
def test_get_has_next_true(self, app):
def test_get_has_next_true(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
@ -355,7 +356,7 @@ class TestWorkspaceListApi:
class TestTenantApi:
def test_post_active_tenant(self, app):
def test_post_active_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -375,7 +376,7 @@ class TestTenantApi:
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app):
def test_post_archived_with_switch(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -397,7 +398,7 @@ class TestTenantApi:
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app):
def test_post_archived_no_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -411,7 +412,7 @@ class TestTenantApi:
with pytest.raises(Unauthorized):
method(api)
def test_post_info_path(self, app):
def test_post_info_path(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -454,7 +455,7 @@ class TestTenantInfoResponse:
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
def test_switch_success(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -477,7 +478,7 @@ class TestSwitchWorkspaceApi:
assert result["result"] == "success"
def test_switch_not_linked(self, app):
def test_switch_not_linked(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -493,7 +494,7 @@ class TestSwitchWorkspaceApi:
with pytest.raises(AccountNotLinkTenantError):
method(api)
def test_switch_tenant_not_found(self, app):
def test_switch_tenant_not_found(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -515,7 +516,7 @@ class TestSwitchWorkspaceApi:
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
@ -538,7 +539,7 @@ class TestCustomConfigWorkspaceApi:
assert result["result"] == "success"
def test_logo_fallback(self, app):
def test_logo_fallback(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
@ -569,7 +570,7 @@ class TestCustomConfigWorkspaceApi:
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app):
def test_no_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -582,7 +583,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(NoFileUploadedError):
method(api)
def test_too_many_files(self, app):
def test_too_many_files(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -601,7 +602,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(TooManyFilesError):
method(api)
def test_invalid_extension(self, app):
def test_invalid_extension(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -616,7 +617,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(UnsupportedFileTypeError):
method(api)
def test_upload_success(self, app):
def test_upload_success(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -648,7 +649,7 @@ class TestWebappLogoWorkspaceApi:
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app):
def test_filename_missing(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -672,7 +673,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(FilenameNotExistsError):
method(api)
def test_file_too_large(self, app):
def test_file_too_large(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -701,7 +702,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(FileTooLargeError):
method(api)
def test_service_unsupported_file(self, app):
def test_service_unsupported_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -732,7 +733,7 @@ class TestWebappLogoWorkspaceApi:
class TestWorkspaceInfoApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
@ -756,7 +757,7 @@ class TestWorkspaceInfoApi:
assert result["result"] == "success"
def test_no_current_tenant(self, app):
def test_no_current_tenant(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
@ -774,7 +775,7 @@ class TestWorkspaceInfoApi:
class TestWorkspacePermissionApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)
@ -799,7 +800,7 @@ class TestWorkspacePermissionApi:
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app):
def test_no_current_tenant(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)

View File

@ -41,7 +41,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_chat_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving parameters for a chat app."""
# Arrange
@ -91,7 +91,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_workflow_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving parameters for a workflow app."""
# Arrange
@ -136,7 +136,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_chat_config_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test that AppUnavailableError is raised when chat app has no config."""
# Arrange
@ -174,7 +174,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_workflow_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
# Arrange
@ -234,7 +234,14 @@ class TestAppMetaApi:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.app.app.AppService")
def test_get_app_meta(
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self,
mock_app_service,
mock_db,
mock_validate_token,
mock_current_app,
mock_user_logged_in,
app: Flask,
mock_app_model,
):
"""Test retrieving app metadata via AppService."""
# Arrange
@ -310,7 +317,7 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving basic app information."""
mock_current_app.login_manager = Mock()
@ -402,7 +409,9 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
def test_get_app_info_with_no_tags(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask
):
"""Test retrieving app info when app has no tags."""
# Arrange
mock_current_app.login_manager = Mock()
@ -453,7 +462,7 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_returns_correct_mode(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, app_mode
):
"""Test that all app modes are correctly returned."""
# Arrange

View File

@ -13,6 +13,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
@ -190,7 +191,7 @@ class TestAudioServiceMockedBehavior:
class TestAudioApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = AudioApi()
handler = _unwrap(api.post)
@ -216,7 +217,7 @@ class TestAudioApi:
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = AudioApi()
handler = _unwrap(api.post)
@ -227,7 +228,7 @@ class TestAudioApi:
with pytest.raises(expected):
handler(api, app_model=app_model, end_user=end_user)
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_unhandled_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
)
@ -242,7 +243,7 @@ class TestAudioApi:
class TestTextApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = TextApi()
@ -259,7 +260,7 @@ class TestTextApi:
assert response == {"audio": "ok"}
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
)

View File

@ -16,6 +16,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
@ -295,7 +296,7 @@ class TestCompletionControllerLogic:
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask):
"""Test CompletionApi.post success path."""
from controllers.service_api.app.completion import CompletionApi
@ -320,7 +321,7 @@ class TestCompletionControllerLogic:
mock_generate_service.generate.assert_called_once()
@patch("controllers.service_api.app.completion.service_api_ns")
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask):
"""Test CompletionApi.post with wrong app mode."""
from controllers.service_api.app.completion import CompletionApi
@ -334,7 +335,7 @@ class TestCompletionControllerLogic:
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask):
"""Test ChatApi.post success path."""
from controllers.service_api.app.completion import ChatApi
@ -355,7 +356,7 @@ class TestCompletionControllerLogic:
assert response == {"text": "compacted"}
@patch("controllers.service_api.app.completion.service_api_ns")
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask):
"""Test ChatApi.post with wrong app mode."""
from controllers.service_api.app.completion import ChatApi
@ -368,7 +369,7 @@ class TestCompletionControllerLogic:
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.AppTaskService")
def test_completion_stop_api_success(self, mock_task_service, app):
def test_completion_stop_api_success(self, mock_task_service, app: Flask):
"""Test CompletionStopApi.post success."""
from controllers.service_api.app.completion import CompletionStopApi
@ -385,7 +386,7 @@ class TestCompletionControllerLogic:
mock_task_service.stop_task.assert_called_once()
@patch("controllers.service_api.app.completion.AppTaskService")
def test_chat_stop_api_success(self, mock_task_service, app):
def test_chat_stop_api_success(self, mock_task_service, app: Flask):
"""Test ChatStopApi.post success."""
from controllers.service_api.app.completion import ChatStopApi

View File

@ -20,6 +20,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
import services
@ -504,7 +505,7 @@ class TestConversationApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_list_last_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _BeginStub:
def __enter__(self):
return SimpleNamespace()
@ -552,7 +553,7 @@ class TestConversationDetailApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"delete",
@ -570,7 +571,7 @@ class TestConversationDetailApiController:
class TestConversationRenameApiController:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"rename",
@ -602,7 +603,7 @@ class TestConversationVariablesApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
@ -621,7 +622,7 @@ class TestConversationVariablesApiController:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
@ -661,7 +662,7 @@ class TestConversationVariablesApiController:
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_type_mismatch(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
@ -687,7 +688,7 @@ class TestConversationVariableDetailApiController:
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
@ -713,7 +714,7 @@ class TestConversationVariableDetailApiController:
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,

View File

@ -16,6 +16,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from controllers.common.errors import (
FilenameNotExistsError,
@ -282,7 +283,7 @@ class TestFileApiPost:
assert status == 201
mock_file_svc_cls.return_value.upload_file.assert_called_once()
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
def test_upload_no_file(self, app: Flask, mock_app_model, mock_end_user):
"""Test NoFileUploadedError when no file in request."""
from controllers.service_api.app.file import FileApi
@ -296,7 +297,7 @@ class TestFileApiPost:
with pytest.raises(NoFileUploadedError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user):
"""Test TooManyFilesError when multiple files uploaded."""
from io import BytesIO
@ -317,7 +318,7 @@ class TestFileApiPost:
with pytest.raises(TooManyFilesError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user):
"""Test UnsupportedFileTypeError when file has no mimetype."""
from io import BytesIO

View File

@ -11,6 +11,7 @@ from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
from flask import Flask
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
@ -281,7 +282,7 @@ class TestHitlServiceApi:
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
self, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",

View File

@ -19,6 +19,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api.app.error import NotChatAppError
@ -390,7 +391,7 @@ class TestMessageListApi:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_conversation_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
@ -409,7 +410,7 @@ class TestMessageListApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_first_message_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
@ -430,7 +431,7 @@ class TestMessageListApi:
class TestMessageFeedbackApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"create_feedback",
@ -452,7 +453,7 @@ class TestMessageFeedbackApi:
class TestAppGetFeedbacksApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
api = AppGetFeedbacksApi()
@ -476,7 +477,7 @@ class TestMessageSuggestedApi:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -492,7 +493,7 @@ class TestMessageSuggestedApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_disabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -508,7 +509,7 @@ class TestMessageSuggestedApi:
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_internal_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -524,7 +525,7 @@ class TestMessageSuggestedApi:
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",

View File

@ -20,6 +20,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.service_api.app.error import NotWorkflowAppError
@ -366,7 +367,7 @@ class TestWorkflowRunRepository:
class TestWorkflowRunDetailApi:
def test_not_workflow_app(self, app) -> None:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -397,7 +398,7 @@ class TestWorkflowRunDetailApi:
class TestWorkflowRunApi:
def test_not_workflow_app(self, app) -> None:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -407,7 +408,7 @@ class TestWorkflowRunApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_rate_limit(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -425,7 +426,7 @@ class TestWorkflowRunApi:
class TestWorkflowRunByIdApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -441,7 +442,7 @@ class TestWorkflowRunByIdApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_draft_workflow(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -459,7 +460,7 @@ class TestWorkflowRunByIdApi:
class TestWorkflowTaskStopApi:
def test_wrong_mode(self, app) -> None:
def test_wrong_mode(self, app: Flask) -> None:
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -469,7 +470,7 @@ class TestWorkflowTaskStopApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
send_mock = Mock()
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
@ -489,7 +490,7 @@ class TestWorkflowTaskStopApi:
class TestWorkflowAppLogApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _BeginStub:
def __enter__(self):
return SimpleNamespace()
@ -557,7 +558,7 @@ class TestWorkflowRunDetailApiGet:
self,
mock_db,
mock_repo_factory,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow run detail retrieval."""
@ -579,7 +580,7 @@ class TestWorkflowRunDetailApiGet:
assert result["status"] == "succeeded"
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
def test_get_workflow_run_wrong_app_mode(self, mock_db, app: Flask):
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
from controllers.service_api.app.workflow import WorkflowRunDetailApi
@ -604,7 +605,7 @@ class TestWorkflowTaskStopApiPost:
self,
mock_queue_mgr,
mock_graph_mgr,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow task stop."""
@ -624,7 +625,7 @@ class TestWorkflowTaskStopApiPost:
mock_graph_mgr.assert_called_once()
mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1")
def test_stop_workflow_task_wrong_app_mode(self, app):
def test_stop_workflow_task_wrong_app_mode(self, app: Flask):
"""Test NotWorkflowAppError when app mode is not workflow."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
@ -649,7 +650,7 @@ class TestWorkflowAppLogApiGet:
self,
mock_db,
mock_wf_svc_cls,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow log retrieval."""

View File

@ -9,6 +9,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.app.error import NotWorkflowAppError
@ -41,7 +42,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
@ -52,7 +53,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_permission_denied(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -70,7 +71,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_finished_run_returns_sse(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -103,7 +104,7 @@ class TestWorkflowEventsApi:
assert payload["task_id"] == "run-1"
assert payload["event"] == "workflow_finished"
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_running_run_streams_events(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -135,7 +136,7 @@ class TestWorkflowEventsApi:
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_running_run_with_snapshot(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",

View File

@ -23,6 +23,7 @@ from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden, NotFound
@ -373,7 +374,7 @@ class TestDatasourcePluginsApiGet:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_success(self, mock_svc_cls, mock_db, app):
def test_get_plugins_success(self, mock_svc_cls, mock_db, app: Flask):
"""Test successful retrieval of datasource plugins."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
@ -396,7 +397,7 @@ class TestDatasourcePluginsApiGet:
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_get_plugins_not_found(self, mock_db, app):
def test_get_plugins_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -407,7 +408,7 @@ class TestDatasourcePluginsApiGet:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app):
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app: Flask):
"""Test empty plugin list."""
mock_db.session.scalar.return_value = Mock()
mock_svc_instance = Mock()
@ -439,7 +440,7 @@ class TestDatasourceNodeRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app):
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app: Flask):
"""Test successful datasource node run."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
@ -473,7 +474,7 @@ class TestDatasourceNodeRunApiPost:
mock_svc_instance.run_datasource_workflow_node.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
def test_post_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -488,7 +489,7 @@ class TestDatasourceNodeRunApiPost:
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app):
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app: Flask):
"""Test AssertionError when current_user is not an Account instance."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
@ -549,7 +550,7 @@ class TestPipelineRunApiPost:
mock_gen_svc.generate.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
def test_post_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -561,7 +562,7 @@ class TestPipelineRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app):
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app: Flask):
"""Test Forbidden when current_user is not an Account."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
@ -585,7 +586,7 @@ class TestFileUploadApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app):
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app: Flask):
"""Test successful file upload."""
mock_current_user.__bool__ = Mock(return_value=True)
@ -621,7 +622,7 @@ class TestFileUploadApiPost:
assert response["name"] == "doc.pdf"
assert response["extension"] == "pdf"
def test_upload_no_file(self, app):
def test_upload_no_file(self, app: Flask):
"""Test error when no file is uploaded."""
with app.test_request_context(
"/datasets/pipeline/file-upload",

View File

@ -18,6 +18,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.segment import (
@ -782,7 +783,7 @@ class TestSegmentApiGet:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -893,7 +894,7 @@ class TestSegmentApiPost:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -946,7 +947,7 @@ class TestSegmentApiPost:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -989,7 +990,7 @@ class TestSegmentApiPost:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1041,7 +1042,7 @@ class TestDatasetSegmentApiDelete:
mock_doc_svc,
mock_dataset_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1086,7 +1087,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_doc_svc,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1282,7 +1283,7 @@ class TestDatasetSegmentApiUpdate:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1322,7 +1323,7 @@ class TestDatasetSegmentApiUpdate:
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1374,7 +1375,7 @@ class TestDatasetSegmentApiGetSingle:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1421,7 +1422,7 @@ class TestDatasetSegmentApiGetSingle:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1460,7 +1461,7 @@ class TestDatasetSegmentApiGetSingle:
self,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1491,7 +1492,7 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1526,7 +1527,7 @@ class TestDatasetSegmentApiGetSingle:
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1570,7 +1571,7 @@ class TestChildChunkApiGet:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1609,7 +1610,7 @@ class TestChildChunkApiGet:
self,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1638,7 +1639,7 @@ class TestChildChunkApiGet:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1670,7 +1671,7 @@ class TestChildChunkApiGet:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1729,7 +1730,7 @@ class TestChildChunkApiPost:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1771,7 +1772,7 @@ class TestChildChunkApiPost:
mock_feature_svc,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1809,7 +1810,7 @@ class TestChildChunkApiPost:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1863,7 +1864,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1913,7 +1914,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1954,7 +1955,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1994,7 +1995,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):

View File

@ -19,6 +19,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.metadata import (
@ -76,7 +77,7 @@ class TestDatasetMetadataCreatePost:
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -106,7 +107,7 @@ class TestDatasetMetadataCreatePost:
def test_create_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -136,7 +137,7 @@ class TestDatasetMetadataCreateGet:
self,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -160,7 +161,7 @@ class TestDatasetMetadataCreateGet:
def test_get_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -201,7 +202,7 @@ class TestDatasetMetadataServiceApiPatch:
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -232,7 +233,7 @@ class TestDatasetMetadataServiceApiPatch:
def test_update_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -273,7 +274,7 @@ class TestDatasetMetadataServiceApiDelete:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -302,7 +303,7 @@ class TestDatasetMetadataServiceApiDelete:
def test_delete_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -336,7 +337,7 @@ class TestDatasetMetadataBuiltInFieldGet:
def test_get_built_in_fields_success(
self,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -382,7 +383,7 @@ class TestDatasetMetadataBuiltInFieldAction:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -414,7 +415,7 @@ class TestDatasetMetadataBuiltInFieldAction:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -441,7 +442,7 @@ class TestDatasetMetadataBuiltInFieldAction:
def test_action_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -485,7 +486,7 @@ class TestDocumentMetadataEditPost:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -513,7 +514,7 @@ class TestDocumentMetadataEditPost:
def test_update_documents_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):

View File

@ -5,6 +5,7 @@ Unit tests for Service API Index endpoint
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.service_api.index import IndexApi
@ -13,7 +14,7 @@ class TestIndexApi:
"""Test suite for IndexApi resource."""
@patch("controllers.service_api.index.dify_config", autospec=True)
def test_get_returns_api_info(self, mock_config, app):
def test_get_returns_api_info(self, mock_config, app: Flask):
"""Test that GET returns API metadata with correct structure."""
# Arrange
mock_config.project.version = "1.0.0-test"
@ -32,7 +33,7 @@ class TestIndexApi:
assert response["api_version"] == "v1"
assert response["server_version"] == "1.0.0-test"
def test_get_response_has_required_fields(self, app):
def test_get_response_has_required_fields(self, app: Flask):
"""Test that response contains all required fields."""
# Arrange
mock_config = MagicMock()

View File

@ -39,7 +39,7 @@ class TestValidateAndGetApiToken:
app.config["TESTING"] = True
return app
def test_missing_authorization_header(self, app):
def test_missing_authorization_header(self, app: Flask):
"""Test that Unauthorized is raised when Authorization header is missing."""
# Arrange
with app.test_request_context("/", method="GET"):
@ -50,7 +50,7 @@ class TestValidateAndGetApiToken:
validate_and_get_api_token("app")
assert "Authorization header must be provided" in str(exc_info.value)
def test_invalid_auth_scheme(self, app):
def test_invalid_auth_scheme(self, app: Flask):
"""Test that Unauthorized is raised when auth scheme is not Bearer."""
# Arrange
with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}):
@ -62,7 +62,7 @@ class TestValidateAndGetApiToken:
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask):
"""Test that valid token returns the ApiToken object."""
# Arrange
mock_api_token = Mock(spec=ApiToken)
@ -84,7 +84,7 @@ class TestValidateAndGetApiToken:
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask):
"""Test that invalid token raises Unauthorized."""
# Arrange
from werkzeug.exceptions import Unauthorized
@ -161,7 +161,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app no longer exists."""
# Arrange
mock_api_token = Mock()
@ -182,7 +182,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app status is abnormal."""
# Arrange
mock_api_token = Mock()
@ -205,7 +205,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app API is disabled."""
# Arrange
mock_api_token = Mock()
@ -240,7 +240,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app):
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that request is allowed when under resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -264,7 +264,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app):
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that Forbidden is raised when at resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -287,7 +287,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app):
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that request is allowed when billing is disabled."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -320,7 +320,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app):
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that add_segment is rejected in SANDBOX plan."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -342,7 +342,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app):
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that non-add_segment operations are allowed in SANDBOX."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -376,7 +376,7 @@ class TestCloudEditionBillingRateLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app):
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app: Flask):
"""Test that request is allowed when within rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -406,7 +406,7 @@ class TestCloudEditionBillingRateLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
@patch("controllers.service_api.wraps.db")
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app):
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app: Flask):
"""Test that Forbidden is raised when over rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -445,7 +445,7 @@ class TestValidateDatasetToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.current_app")
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app):
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask):
"""Test that valid dataset token allows access."""
# Arrange
# Use standard Mock for login_manager
@ -487,7 +487,7 @@ class TestValidateDatasetToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app):
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app: Flask):
"""Test that NotFound is raised when dataset doesn't exist."""
# Arrange
mock_api_token = Mock()

View File

@ -13,7 +13,7 @@ class TestExternalDataFetch:
app = Flask(__name__)
return app
def test_fetch_success(self, app):
def test_fetch_success(self, app: Flask):
with app.app_context():
fetcher = ExternalDataFetch()
@ -79,7 +79,7 @@ class TestExternalDataFetch:
assert result_inputs == inputs
assert result_inputs is not inputs # Should be a copy
def test_fetch_with_none_variable(self, app):
def test_fetch_with_none_variable(self, app: Flask):
with app.app_context():
fetcher = ExternalDataFetch()
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})
@ -95,7 +95,7 @@ class TestExternalDataFetch:
assert "var1" not in result_inputs
assert result_inputs == {"in": "val"}
def test_query_external_data_tool(self, app):
def test_query_external_data_tool(self, app: Flask):
fetcher = ExternalDataFetch()
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})

View File

@ -1,79 +0,0 @@
from unittest.mock import MagicMock, patch
from models.account import TenantPluginPermission
MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock()
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
session.begin.return_value.__enter__ = MagicMock(return_value=session)
session.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_factory = MagicMock()
mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session
class TestGetPermission:
def test_returns_permission_when_found(self):
p1, session = _patched_session()
permission = MagicMock()
session.scalar.return_value = permission
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
assert result is permission
def test_returns_none_when_not_found(self):
p1, session = _patched_session()
session.scalar.return_value = None
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
assert result is None
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, session = _patched_session()
session.scalar.return_value = None
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
)
assert result is True
session.begin.assert_called_once()
session.add.assert_called_once()
def test_updates_existing_permission(self):
p1, session = _patched_session()
existing = MagicMock()
session.scalar.return_value = existing
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
)
assert result is True
session.begin.assert_called_once()
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.add.assert_not_called()

View File

@ -315,20 +315,12 @@
"count": 4
}
},
"web/app/components/app/configuration/config-var/config-modal/type-select.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/app/configuration/config-var/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/app/configuration/config-var/select-var-type.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
@ -363,9 +355,6 @@
}
},
"web/app/components/app/configuration/config/assistant-type-picker/index.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
@ -401,11 +390,6 @@
"count": 1
}
},
"web/app/components/app/configuration/config/automatic/version-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx": {
"no-restricted-imports": {
"count": 1
@ -774,11 +758,6 @@
"count": 1
}
},
"web/app/components/base/chat/chat-with-history/sidebar/rename-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/base/chat/chat/answer/agent-content.tsx": {
"style/multiline-ternary": {
"count": 2
@ -800,11 +779,6 @@
"count": 1
}
},
"web/app/components/base/chat/chat/answer/operation.tsx": {
"no-restricted-imports": {
"count": 2
}
},
"web/app/components/base/chat/chat/answer/workflow-process.tsx": {
"react/set-state-in-effect": {
"count": 1
@ -1055,14 +1029,6 @@
"count": 3
}
},
"web/app/components/base/form/components/base/base-field.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 3
}
},
"web/app/components/base/form/components/base/base-form.tsx": {
"ts/no-explicit-any": {
"count": 6
@ -1589,14 +1555,6 @@
"count": 1
}
},
"web/app/components/base/modal/modal.stories.tsx": {
"no-console": {
"count": 4
},
"react/set-state-in-effect": {
"count": 1
}
},
"web/app/components/base/new-audio-button/index.tsx": {
"ts/no-explicit-any": {
"count": 1
@ -2116,11 +2074,6 @@
"count": 1
}
},
"web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx": {
"ts/no-explicit-any": {
"count": 1
@ -2389,11 +2342,6 @@
"count": 1
}
},
"web/app/components/datasets/settings/index-method/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/develop/code.tsx": {
"ts/no-empty-object-type": {
"count": 1
@ -2579,11 +2527,8 @@
"erasable-syntax-only/enums": {
"count": 1
},
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 3
"count": 2
}
},
"web/app/components/header/account-setting/model-provider-page/declarations.ts": {
@ -2612,11 +2557,6 @@
"count": 1
}
},
"web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts": {
"no-barrel-files/no-barrel-files": {
"count": 6
@ -2912,44 +2852,11 @@
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts": {
"erasable-syntax-only/enums": {
"count": 2
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": {
"no-barrel-files/no-barrel-files": {
"count": 3
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts": {
"erasable-syntax-only/enums": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx": {
"erasable-syntax-only/enums": {
"count": 1
},
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": {
"no-barrel-files/no-barrel-files": {
"count": 2
@ -2978,11 +2885,6 @@
"count": 7
}
},
"web/app/components/plugins/plugin-detail-panel/tool-selector/components/schema-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-item.tsx": {
"no-restricted-imports": {
"count": 1
@ -2998,11 +2900,6 @@
"count": 5
}
},
"web/app/components/plugins/plugin-item/action.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-item/index.tsx": {
"no-restricted-imports": {
"count": 1
@ -3681,11 +3578,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/error-handle/types.ts": {
"erasable-syntax-only/enums": {
"count": 1
@ -3782,11 +3674,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/variable/var-type-picker.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts": {
"react/no-unnecessary-use-prefix": {
"count": 2
@ -4106,16 +3993,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/if-else/components/condition-number-input.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/if-else/default.ts": {
"ts/no-explicit-any": {
"count": 1
@ -4151,11 +4028,6 @@
"count": 6
}
},
"web/app/components/workflow/nodes/knowledge-base/components/chunk-structure/selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/hooks.tsx": {
"ts/no-explicit-any": {
"count": 4
@ -4199,11 +4071,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/metadata-filter-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-retrieval/default.ts": {
"ts/no-explicit-any": {
"count": 1
@ -4285,11 +4152,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/type-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts": {
"ts/no-explicit-any": {
"count": 1
@ -4341,16 +4203,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/loop/components/condition-list/condition-operator.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/loop/components/condition-number-input.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx": {
"ts/no-explicit-any": {
"count": 3
@ -4522,9 +4374,6 @@
}
},
"web/app/components/workflow/nodes/tool/components/tool-form/item.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}

View File

@ -4,6 +4,17 @@ import antfu, { GLOB_MARKDOWN } from '@antfu/eslint-config'
import md from 'eslint-markdown'
import markdownPreferences from 'eslint-plugin-markdown-preferences'
const GENERATED_IGNORES = [
'**/storybook-static/',
'**/.next/',
'web/next/',
'web/next-env.d.ts',
'**/dist/',
'**/coverage/',
'e2e/.auth/',
'e2e/cucumber-report/',
]
export default antfu(
{
ignores: original => [
@ -15,6 +26,7 @@ export default antfu(
'!package.json',
'!pnpm-workspace.yaml',
'!vite.config.ts',
...GENERATED_IGNORES,
...original,
],
typescript: {

View File

@ -2,7 +2,7 @@
"name": "dify",
"type": "module",
"private": true,
"packageManager": "pnpm@11.0.0",
"packageManager": "pnpm@11.0.6",
"engines": {
"node": "^22.22.1"
},

View File

@ -88,9 +88,9 @@ Every overlay primitive uses a single, shared z-index. Do **not** override it at
| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop |
| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. |
Rationale: during Dify's migration from legacy `portal-to-follow-elem` / `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins.
Rationale: during Dify's migration from legacy `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins.
See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history and the remaining legacy allowlist. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`.
See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`.
### Rules

2455
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@ saveExact: true
catalogMode: prefer
dedupeDirectDeps: true
engineStrict: true
minimumReleaseAge: 1440
minimumReleaseAge: 0
optimisticRepeatInstall: true
verifyDepsBeforeRun: install
resolutionMode: time-based
@ -54,8 +54,8 @@ overrides:
yaml@>=2.0.0 <2.8.3: 2.8.3
yauzl@<3.2.1: 3.2.1
catalog:
'@amplitude/analytics-browser': 2.42.0
'@amplitude/plugin-session-replay-browser': 1.28.1
'@amplitude/analytics-browser': 2.42.1
'@amplitude/plugin-session-replay-browser': 1.29.0
'@antfu/eslint-config': 8.2.0
'@base-ui/react': 1.4.1
'@chromatic-com/storybook': 5.1.2
@ -65,11 +65,11 @@ catalog:
'@eslint-react/eslint-plugin': 3.0.0
'@eslint/js': 10.0.1
'@floating-ui/react': 0.27.19
'@formatjs/intl-localematcher': 0.8.4
'@formatjs/intl-localematcher': 0.8.6
'@headlessui/react': 2.2.10
'@heroicons/react': 2.2.0
'@hey-api/openapi-ts': 0.97.0
'@hono/node-server': 2.0.0
'@hey-api/openapi-ts': 0.97.1
'@hono/node-server': 2.0.1
'@iconify-json/heroicons': 1.2.3
'@iconify-json/ri': 1.2.10
'@lexical/code': 0.44.0
@ -85,42 +85,42 @@ catalog:
'@monaco-editor/react': 4.7.0
'@next/eslint-plugin-next': 16.2.4
'@next/mdx': 16.2.4
'@orpc/client': 1.14.0
'@orpc/contract': 1.14.0
'@orpc/openapi-client': 1.14.0
'@orpc/tanstack-query': 1.14.0
'@orpc/client': 1.14.1
'@orpc/contract': 1.14.1
'@orpc/openapi-client': 1.14.1
'@orpc/tanstack-query': 1.14.1
'@playwright/test': 1.59.1
'@remixicon/react': 4.9.0
'@rgrove/parse-xml': 4.2.0
'@sentry/react': 10.50.0
'@storybook/addon-docs': 10.3.5
'@storybook/addon-links': 10.3.5
'@storybook/addon-onboarding': 10.3.5
'@storybook/addon-themes': 10.3.5
'@storybook/nextjs-vite': 10.3.5
'@storybook/react': 10.3.5
'@storybook/react-vite': 10.3.5
'@sentry/react': 10.51.0
'@storybook/addon-docs': 10.3.6
'@storybook/addon-links': 10.3.6
'@storybook/addon-onboarding': 10.3.6
'@storybook/addon-themes': 10.3.6
'@storybook/nextjs-vite': 10.3.6
'@storybook/react': 10.3.6
'@storybook/react-vite': 10.3.6
'@streamdown/math': 1.0.2
'@svgdotjs/svg.js': 3.2.5
'@t3-oss/env-nextjs': 0.13.11
'@tailwindcss/postcss': 4.2.4
'@tailwindcss/typography': 0.5.19
'@tailwindcss/vite': 4.2.4
'@tanstack/eslint-plugin-query': 5.100.6
'@tanstack/eslint-plugin-query': 5.100.9
'@tanstack/react-devtools': 0.10.2
'@tanstack/react-form': 1.29.1
'@tanstack/react-form-devtools': 0.2.22
'@tanstack/react-hotkeys': 0.10.0
'@tanstack/react-query': 5.100.6
'@tanstack/react-query-devtools': 5.100.6
'@tanstack/react-query': 5.100.9
'@tanstack/react-query-devtools': 5.100.9
'@tanstack/react-virtual': 3.13.24
'@testing-library/dom': 10.4.1
'@testing-library/jest-dom': 6.9.1
'@testing-library/react': 16.3.2
'@testing-library/user-event': 14.6.1
'@tsslint/cli': 3.1.0
'@tsslint/compat-eslint': 3.1.0
'@tsslint/config': 3.1.0
'@tsslint/cli': 3.1.1
'@tsslint/compat-eslint': 3.1.1
'@tsslint/config': 3.1.1
'@types/js-cookie': 3.0.6
'@types/js-yaml': 4.0.9
'@types/negotiator': 0.6.4
@ -129,9 +129,9 @@ catalog:
'@types/react': 19.2.14
'@types/react-dom': 19.2.3
'@types/sortablejs': 1.15.9
'@typescript-eslint/eslint-plugin': 8.59.1
'@typescript-eslint/parser': 8.59.1
'@typescript/native-preview': 7.0.0-dev.20260428.1
'@typescript-eslint/eslint-plugin': 8.59.2
'@typescript-eslint/parser': 8.59.2
'@typescript/native-preview': 7.0.0-dev.20260505.1
'@vitejs/plugin-react': 6.0.1
'@vitejs/plugin-rsc': 0.5.25
'@vitest/coverage-v8': 4.1.5
@ -149,44 +149,45 @@ catalog:
cron-parser: 5.5.0
dayjs: 1.11.20
decimal.js: 10.6.0
dompurify: 3.4.1
dompurify: 3.4.2
echarts: 6.0.0
echarts-for-react: 3.0.6
elkjs: 0.11.1
embla-carousel-autoplay: 8.6.0
embla-carousel-react: 8.6.0
emoji-mart: 5.6.0
es-toolkit: 1.46.0
eslint: 10.2.1
eslint-markdown: 0.7.0
es-toolkit: 1.46.1
eslint: 10.3.0
eslint-markdown: 0.8.0
eslint-plugin-better-tailwindcss: 4.5.0
eslint-plugin-hyoban: 0.14.1
eslint-plugin-markdown-preferences: 0.41.1
eslint-plugin-no-barrel-files: 1.3.1
eslint-plugin-react-refresh: 0.5.2
eslint-plugin-sonarjs: 4.0.3
eslint-plugin-storybook: 10.3.5
eslint-plugin-storybook: 10.3.6
fast-deep-equal: 3.1.3
fuse.js: 7.2.0
happy-dom: 20.9.0
hast-util-to-jsx-runtime: 2.3.6
hono: 4.12.15
hono: 4.12.17
html-entities: 2.6.0
html-to-image: 1.11.13
i18next: 26.0.8
i18next-resources-to-backend: 1.2.1
iconify-import-svg: 0.2.0
immer: 11.1.4
jotai: 2.19.1
immer: 11.1.6
jotai: 2.20.0
js-audio-recorder: 1.0.7
js-cookie: 3.0.5
js-yaml: 4.1.1
jsonschema: 1.5.0
katex: 0.16.45
knip: 6.7.0
knip: 6.11.0
ky: 2.0.2
lamejs: 1.2.1
lexical: 0.44.0
loro-crdt: 1.12.0
loro-crdt: 1.12.1
mermaid: 11.14.0
mime: 4.1.0
mitt: 3.0.1
@ -196,14 +197,14 @@ catalog:
nuqs: 2.8.9
pinyin-pro: 3.28.1
playwright: 1.59.1
postcss: 8.5.12
postcss: 8.5.14
qrcode.react: 4.2.0
qs: 6.15.1
react: 19.2.5
react-18-input-autosize: 3.0.0
react-dom: 19.2.5
react-easy-crop: 5.5.7
react-hotkeys-hook: 5.2.4
react-hotkeys-hook: 5.3.2
react-i18next: 16.5.8
react-multi-email: 1.0.25
react-papaparse: 4.4.0
@ -220,25 +221,25 @@ catalog:
socket.io-client: 4.8.3
sortablejs: 1.15.7
std-semver: 1.0.8
storybook: 10.3.5
storybook: 10.3.6
streamdown: 2.5.0
string-ts: 2.3.1
tailwind-merge: 3.5.0
tailwindcss: 4.2.4
tldts: 7.0.29
tldts: 7.0.30
tsx: 4.21.0
typescript: 6.0.3
uglify-js: 3.19.3
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 14.0.0
vinext: 0.0.45
vinext: 0.0.47
vite: npm:@voidzero-dev/vite-plus-core@0.1.20
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.20
vitest: npm:@voidzero-dev/vite-plus-test@0.1.20
vitest-browser-react: 2.2.0
vitest-canvas-mock: 1.1.4
zod: 4.3.6
zod: 4.4.3
zundo: 2.3.0
zustand: 5.0.12
zustand: 5.0.13

View File

@ -5,9 +5,9 @@
## Overlay Components (Mandatory)
- `../packages/dify-ui/README.md` is the permanent contract for overlay primitives, portals, root `isolation: isolate`, and the `z-1002` / `z-1003` layering.
- `./docs/overlay-migration.md` is the source of truth for the ongoing migration (deprecated import paths, allowlist, coexistence rules).
- `./docs/overlay-migration.md` is the source of truth for the ongoing migration (deprecated import paths and coexistence rules).
- In new or modified code, use only overlay primitives from `@langgenius/dify-ui/*`.
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding).
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them.
## Query & Mutation (Mandatory)

View File

@ -0,0 +1,172 @@
import type { ReactNode } from 'react'
import * as React from 'react'
const DropdownMenuContext = React.createContext({
open: false,
onOpenChange: (_open: boolean) => {},
})
type DropdownMenuProps = {
children?: ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type DropdownMenuTriggerProps = React.HTMLAttributes<HTMLElement> & {
children?: ReactNode
nativeButton?: boolean
render?: React.ReactElement
}
type DropdownMenuContentProps = React.HTMLAttributes<HTMLDivElement> & {
children?: ReactNode
placement?: string
sideOffset?: number
alignOffset?: number
popupClassName?: string
}
export const DropdownMenu = ({
children,
open,
onOpenChange,
}: DropdownMenuProps) => {
const [localOpen, setLocalOpen] = React.useState(false)
const resolvedOpen = open ?? localOpen
const handleOpenChange = React.useCallback((nextOpen: boolean) => {
setLocalOpen(nextOpen)
onOpenChange?.(nextOpen)
}, [onOpenChange])
return (
<DropdownMenuContext.Provider value={{ open: resolvedOpen, onOpenChange: handleOpenChange }}>
<div data-testid="dropdown-menu" data-open={String(resolvedOpen)}>
{children}
</div>
</DropdownMenuContext.Provider>
)
}
export const DropdownMenuTrigger = ({
children,
render,
nativeButton: _nativeButton,
onClick,
...props
}: DropdownMenuTriggerProps) => {
const { open, onOpenChange } = React.useContext(DropdownMenuContext)
const node = render ?? children
const isNativeButton = React.isValidElement(node) && node.type === 'button'
const handleClick = (event: React.MouseEvent<HTMLElement>) => {
onClick?.(event)
if (!event.defaultPrevented)
onOpenChange(!open)
}
if (React.isValidElement(node)) {
const triggerElement = node as React.ReactElement<Record<string, unknown>>
const childProps = (triggerElement.props ?? {}) as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
const triggerProps = props as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
const role = childProps.role ?? triggerProps.role ?? (!isNativeButton && (childProps['aria-label'] || triggerProps['aria-label']) ? 'button' : undefined)
return React.cloneElement(triggerElement, {
...props,
...childProps,
'data-testid': childProps['data-testid'] ?? triggerProps['data-testid'] ?? 'dropdown-menu-trigger',
role,
'tabIndex': childProps.tabIndex ?? triggerProps.tabIndex ?? (role === 'button' ? 0 : undefined),
'onClick': (event: React.MouseEvent<HTMLElement>) => {
childProps.onClick?.(event)
handleClick(event)
},
}, render ? (children ?? childProps.children) : childProps.children)
}
return (
<div data-testid="dropdown-menu-trigger" role="button" tabIndex={0} onClick={handleClick} {...props}>
{node}
</div>
)
}
export const DropdownMenuContent = ({
children,
className,
popupClassName,
placement,
sideOffset,
alignOffset,
...props
}: DropdownMenuContentProps) => {
const { open } = React.useContext(DropdownMenuContext)
if (!open)
return null
return (
<div
data-testid="dropdown-menu-content"
data-placement={placement}
data-side-offset={sideOffset}
data-align-offset={alignOffset}
className={className || popupClassName}
{...props}
>
{children}
</div>
)
}
export const DropdownMenuItem = ({
children,
onClick,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode }) => (
<div role="menuitem" onClick={onClick} {...props}>
{children}
</div>
)
export const DropdownMenuRadioGroup = ({
children,
onValueChange,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode, value?: unknown, onValueChange?: (value: unknown) => void }) => (
<div
role="radiogroup"
{...props}
data-on-value-change={onValueChange ? 'true' : undefined}
>
{React.Children.map(children, (child) => {
if (!React.isValidElement(child))
return child
return React.cloneElement(child as React.ReactElement<{ __onValueChange?: (value: unknown) => void }>, { __onValueChange: onValueChange })
})}
</div>
)
export const DropdownMenuRadioItem = ({
children,
value,
onClick,
__onValueChange,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode, value?: unknown, __onValueChange?: (value: unknown) => void }) => (
<div
role="radio"
onClick={(event) => {
onClick?.(event)
__onValueChange?.(value)
}}
{...props}
>
{children}
</div>
)
export const DropdownMenuRadioItemIndicator = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuCheckboxItem = DropdownMenuItem
export const DropdownMenuCheckboxItemIndicator = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuLabel = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuSeparator = (props: React.HTMLAttributes<HTMLDivElement>) => <div role="separator" {...props} />
export const DropdownMenuSub = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuSubTrigger = DropdownMenuItem
export const DropdownMenuSubContent = ({ children }: { children?: ReactNode }) => <>{children}</>

View File

@ -23,17 +23,25 @@ type PopoverContentProps = React.HTMLAttributes<HTMLDivElement> & {
placement?: string
sideOffset?: number
alignOffset?: number
popupClassName?: string
positionerProps?: React.HTMLAttributes<HTMLDivElement>
popupProps?: React.HTMLAttributes<HTMLDivElement>
}
export const Popover = ({
children,
open = false,
open,
onOpenChange,
}: PopoverProps) => {
const [localOpen, setLocalOpen] = React.useState(false)
const resolvedOpen = open ?? localOpen
const handleOpenChange = React.useCallback((nextOpen: boolean) => {
setLocalOpen(nextOpen)
onOpenChange?.(nextOpen)
}, [onOpenChange])
React.useEffect(() => {
if (!open)
if (!resolvedOpen)
return
const handleMouseDown = (event: MouseEvent) => {
@ -41,12 +49,12 @@ export const Popover = ({
if (target?.closest?.('[data-popover-trigger="true"], [data-popover-content="true"]'))
return
onOpenChange?.(false)
handleOpenChange(false)
}
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape')
onOpenChange?.(false)
handleOpenChange(false)
}
document.addEventListener('mousedown', handleMouseDown)
@ -56,15 +64,15 @@ export const Popover = ({
document.removeEventListener('mousedown', handleMouseDown)
document.removeEventListener('keydown', handleKeyDown)
}
}, [open, onOpenChange])
}, [resolvedOpen, handleOpenChange])
return (
<PopoverContext.Provider value={{
open,
onOpenChange: onOpenChange ?? (() => {}),
open: resolvedOpen,
onOpenChange: handleOpenChange,
}}
>
<div data-testid="popover" data-open={String(open)}>
<div data-testid="popover" data-open={String(resolvedOpen)}>
{children}
</div>
</PopoverContext.Provider>
@ -84,11 +92,12 @@ export const PopoverTrigger = ({
if (React.isValidElement(node)) {
const triggerElement = node as React.ReactElement<Record<string, unknown>>
const childProps = (triggerElement.props ?? {}) as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
const triggerProps = props as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
return React.cloneElement(triggerElement, {
...props,
...childProps,
'data-testid': childProps['data-testid'] ?? 'popover-trigger',
'data-testid': childProps['data-testid'] ?? triggerProps['data-testid'] ?? 'popover-trigger',
'data-popover-trigger': 'true',
'onClick': (event: React.MouseEvent<HTMLElement>) => {
childProps.onClick?.(event)
@ -97,7 +106,7 @@ export const PopoverTrigger = ({
return
onOpenChange(!open)
},
})
}, render ? (children ?? childProps.children) : childProps.children)
}
return (
@ -123,6 +132,7 @@ export const PopoverContent = ({
placement,
sideOffset,
alignOffset,
popupClassName,
positionerProps,
popupProps,
...props
@ -139,7 +149,7 @@ export const PopoverContent = ({
data-placement={placement}
data-side-offset={sideOffset}
data-align-offset={alignOffset}
className={className}
className={className || popupClassName}
{...positionerProps}
{...popupProps}
{...props}

View File

@ -0,0 +1,65 @@
import type { ReactNode } from 'react'
import * as React from 'react'
const SelectContext = React.createContext({
value: undefined as unknown,
onValueChange: (_value: unknown) => {},
})
type SelectProps = {
children?: ReactNode
value?: unknown
onValueChange?: (value: unknown) => void
}
export const Select = ({
children,
value,
onValueChange,
}: SelectProps) => (
<SelectContext.Provider value={{ value, onValueChange: onValueChange ?? (() => {}) }}>
<div data-testid="select-root">{children}</div>
</SelectContext.Provider>
)
export const SelectTrigger = ({
children,
...props
}: React.ButtonHTMLAttributes<HTMLButtonElement> & { children?: ReactNode }) => (
<button type="button" {...props}>
{children}
</button>
)
export const SelectValue = ({ placeholder }: { placeholder?: ReactNode }) => <>{placeholder}</>
export const SelectContent = ({ children }: { children?: ReactNode }) => (
<div data-testid="select-content">{children}</div>
)
export const SelectItem = ({
children,
value,
onClick,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode, value?: unknown }) => {
const select = React.useContext(SelectContext)
return (
<div
role="option"
onClick={(event) => {
onClick?.(event)
select.onValueChange(value)
}}
{...props}
>
{children}
</div>
)
}
export const SelectItemText = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectItemIndicator = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectGroup = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectLabel = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectSeparator = (props: React.HTMLAttributes<HTMLDivElement>) => <div role="separator" {...props} />

View File

@ -0,0 +1,95 @@
import type { ReactNode } from 'react'
import * as React from 'react'
const TooltipContext = React.createContext({
open: false,
onOpenChange: (_open: boolean) => {},
})
type TooltipProps = {
children?: ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
export const Tooltip = ({ children, open, onOpenChange }: TooltipProps) => {
const [localOpen, setLocalOpen] = React.useState(false)
const resolvedOpen = open ?? localOpen
const handleOpenChange = React.useCallback((nextOpen: boolean) => {
setLocalOpen(nextOpen)
onOpenChange?.(nextOpen)
}, [onOpenChange])
return (
<TooltipContext.Provider value={{ open: resolvedOpen, onOpenChange: handleOpenChange }}>
{children}
</TooltipContext.Provider>
)
}
export const TooltipTrigger = ({
children,
render,
nativeButton: _nativeButton,
...props
}: React.HTMLAttributes<HTMLElement> & { children?: ReactNode, render?: React.ReactElement, nativeButton?: boolean }) => {
const { open, onOpenChange } = React.useContext(TooltipContext)
const node = render ?? children
if (React.isValidElement(node)) {
const triggerElement = node as React.ReactElement<Record<string, unknown>>
const childProps = (triggerElement.props ?? {}) as React.HTMLAttributes<HTMLElement>
return React.cloneElement(triggerElement, {
...props,
...childProps,
onMouseEnter: (event: React.MouseEvent<HTMLElement>) => {
childProps.onMouseEnter?.(event)
props.onMouseEnter?.(event)
onOpenChange(true)
},
onMouseLeave: (event: React.MouseEvent<HTMLElement>) => {
childProps.onMouseLeave?.(event)
props.onMouseLeave?.(event)
onOpenChange(false)
},
onClick: (event: React.MouseEvent<HTMLElement>) => {
childProps.onClick?.(event)
props.onClick?.(event)
onOpenChange(!open)
},
})
}
return (
<span
{...props}
onMouseEnter={(event) => {
props.onMouseEnter?.(event)
onOpenChange(true)
}}
onMouseLeave={(event) => {
props.onMouseLeave?.(event)
onOpenChange(false)
}}
onClick={(event) => {
props.onClick?.(event)
onOpenChange(!open)
}}
>
{node}
</span>
)
}
export const TooltipContent = ({
children,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode }) => {
const { open } = React.useContext(TooltipContext)
if (!open)
return null
return <div {...props}>{children}</div>
}
export const TooltipProvider = ({ children }: { children?: ReactNode }) => <>{children}</>

View File

@ -95,37 +95,8 @@ vi.mock('@/app/components/workflow/utils', () => ({
getKeyboardKeyNameBySystem: (key: string) => key,
}))
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const React = await vi.importActual<typeof import('react')>('react')
const OpenContext = React.createContext(false)
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<OpenContext.Provider value={open}>
<div>{children}</div>
</OpenContext.Provider>
),
PortalToFollowElemTrigger: ({
children,
onClick,
}: {
children: React.ReactNode
onClick?: () => void
}) => (
<button type="button" data-testid="portal-trigger" onClick={onClick}>
{children}
</button>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = React.useContext(OpenContext)
return open ? <div>{children}</div> : null
},
}
})
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
}))
vi.mock('@langgenius/dify-ui/dropdown-menu', () => import('@/__mocks__/base-ui-dropdown-menu'))
vi.mock('@langgenius/dify-ui/tooltip', () => import('@/__mocks__/base-ui-tooltip'))
vi.mock('@/app/components/app-sidebar/app-info', () => ({
default: ({

View File

@ -122,33 +122,7 @@ vi.mock('@/app/components/app/app-access-control', () => ({
default: () => <div data-testid="app-access-control" />,
}))
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const React = await vi.importActual<typeof import('react')>('react')
const OpenContext = React.createContext(false)
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<OpenContext.Provider value={open}>
<div>{children}</div>
</OpenContext.Provider>
),
PortalToFollowElemTrigger: ({
children,
onClick,
}: {
children: React.ReactNode
onClick?: () => void
}) => (
<div onClick={onClick}>
{children}
</div>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = React.useContext(OpenContext)
return open ? <div>{children}</div> : null
},
}
})
vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover'))
vi.mock('@/app/components/workflow/utils', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',

View File

@ -53,6 +53,16 @@ vi.mock('@/next/navigation', () => ({
}),
}))
vi.mock('@tanstack/react-query', async (importOriginal) => {
const actual = await importOriginal<typeof import('@tanstack/react-query')>()
return {
...actual,
useQuery: () => ({
data: [],
}),
}
})
// Mock headless UI Popover so it renders content without transition
vi.mock('@headlessui/react', async () => {
const actual = await vi.importActual<typeof import('@headlessui/react')>('@headlessui/react')

View File

@ -9,7 +9,7 @@ import type { ReactElement, ReactNode } from 'react'
*/
import type { AppListResponse } from '@/models/app'
import type { App } from '@/types/app'
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createSystemFeaturesWrapper } from '@/__tests__/utils/mock-system-features'
import List from '@/app/components/apps/list'
@ -92,6 +92,9 @@ vi.mock('@tanstack/react-query', async (importOriginal) => {
const actual = await importOriginal<typeof import('@tanstack/react-query')>()
return {
...actual,
useQuery: () => ({
data: [],
}),
useInfiniteQuery: () => ({
data: { pages: mockPages },
isLoading: mockIsLoading,
@ -360,13 +363,18 @@ describe('App List Browsing Flow', () => {
expect(input).toBeInTheDocument()
})
it('should allow typing in search input', () => {
it('should update search query when typing in search input', async () => {
mockPages = [createPage([createMockApp()])]
renderList()
const { onUrlUpdate } = renderList()
const input = document.querySelector('input')!
const input = screen.getByPlaceholderText('common.operation.search')
fireEvent.change(input, { target: { value: 'test search' } })
expect(input.value).toBe('test search')
await waitFor(() => {
expect(onUrlUpdate).toHaveBeenCalled()
})
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(lastCall.searchParams.get('keywords')).toBe('test search')
})
})

View File

@ -186,7 +186,12 @@ describe('Human input share form', () => {
action: 'approve',
inputs: {
summary: 'updated summary',
attachments: [mockContentItemState.uploadedFile],
attachments: [{
type: 'document',
transfer_method: TransferMethod.local_file,
url: '',
upload_file_id: 'upload-file-1',
}],
},
},
}, expect.objectContaining({
@ -208,7 +213,12 @@ describe('Human input share form', () => {
action: 'approve',
inputs: {
summary: 'initial summary',
attachments: [mockContentItemState.uploadedFile],
attachments: [{
type: 'document',
transfer_method: TransferMethod.local_file,
url: '',
upload_file_id: 'upload-file-1',
}],
},
},
}, expect.objectContaining({

View File

@ -16,7 +16,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo'
type LoadedFormContentProps = {
formData: FormData
isSubmitting: boolean
onSubmit: (inputs: Record<string, HumanInputFieldValue>, actionID: string) => void
onSubmit: (inputs: Record<string, HumanInputFieldValue>, actionID: string, formInputs: FormData['inputs']) => void
}
const LoadedFormContent = ({
@ -40,7 +40,7 @@ const LoadedFormContent = ({
}
const submit = (actionID: string) => {
onSubmit(inputs, actionID)
onSubmit(inputs, actionID, formData.inputs)
}
const isActionDisabled = isSubmitting || hasInvalidRequiredHumanInput(formData.inputs, inputs)

View File

@ -1,14 +1,22 @@
import type { HumanInputFieldValue } from '@/app/components/base/chat/chat/answer/human-input-content/field-renderer'
import type { FormInputItem } from '@/app/components/workflow/nodes/human-input/types'
import { useCallback, useState } from 'react'
import { getProcessedHumanInputFormInputs } from '@/app/components/base/chat/chat/answer/human-input-content/utils'
import { useSubmitHumanInputForm } from '@/service/use-share'
export const useFormSubmit = (token: string) => {
const [success, setSuccess] = useState(false)
const { mutate: submitForm, isPending: isSubmitting } = useSubmitHumanInputForm()
const submit = useCallback((inputs: Record<string, HumanInputFieldValue>, actionID: string) => {
const submit = useCallback((inputs: Record<string, HumanInputFieldValue>, actionID: string, formInputs: FormInputItem[]) => {
submitForm(
{ token, data: { inputs, action: actionID } },
{
token,
data: {
inputs: getProcessedHumanInputFormInputs(formInputs, inputs) || {},
action: actionID,
},
},
{
onSuccess: () => {
setSuccess(true)

View File

@ -121,25 +121,7 @@ vi.mock('../../app-access-control', () => ({
),
}))
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const ReactModule = await vi.importActual<typeof import('react')>('react')
const OpenContext = ReactModule.createContext(false)
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<OpenContext.Provider value={open}>
<div>{children}</div>
</OpenContext.Provider>
),
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => (
<div onClick={onClick}>{children}</div>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = ReactModule.useContext(OpenContext)
return open ? <div>{children}</div> : null
},
}
})
vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover'))
vi.mock('../sections', () => ({
PublisherSummarySection: (props: Record<string, any>) => {

View File

@ -1,8 +1,8 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import SelectVarType from '../select-var-type'
describe('SelectVarType', () => {
it('should open the menu and return the selected variable type', () => {
it('should open the menu and return the selected variable type', async () => {
const onChange = vi.fn()
render(<SelectVarType onChange={onChange} />)
@ -11,6 +11,8 @@ describe('SelectVarType', () => {
fireEvent.click(screen.getByText('appDebug.variableConfig.checkbox'))
expect(onChange).toHaveBeenCalledWith('checkbox')
expect(screen.queryByText('appDebug.variableConfig.checkbox')).not.toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByText('appDebug.variableConfig.checkbox')).not.toBeInTheDocument()
})
})
})

View File

@ -3,6 +3,8 @@ import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import TypeSelector from '../type-select'
vi.mock('@langgenius/dify-ui/select', () => import('@/__mocks__/base-ui-select'))
vi.mock('@/app/components/workflow/nodes/_base/components/input-var-type-icon', () => ({
default: ({ type }: { type: string }) => <span>{type}</span>,
}))
@ -25,7 +27,7 @@ describe('TypeSelector', () => {
await user.click(screen.getByRole('combobox'))
const [, numberOption] = await screen.findAllByRole('option')
await user.click(numberOption)
await user.click(numberOption!)
expect(onSelect).toHaveBeenCalledWith({ value: 'number', name: 'Number' })
})
@ -47,7 +49,7 @@ describe('TypeSelector', () => {
await user.click(screen.getByRole('combobox'))
const [, numberOption] = await screen.findAllByRole('option')
const popup = numberOption.closest('[data-side]')
const popup = numberOption!.closest('[data-side]')
expect(popup).toHaveClass('w-(--anchor-width)')
})

View File

@ -2,7 +2,14 @@
import type { FC } from 'react'
import type { InputVarType } from '@/app/components/workflow/types'
import { cn } from '@langgenius/dify-ui/cn'
import { Select, SelectContent, SelectItem, SelectItemText, SelectTrigger } from '@langgenius/dify-ui/select'
import {
Select,
SelectContent,
SelectItem,
SelectItemIndicator,
SelectItemText,
SelectTrigger,
} from '@langgenius/dify-ui/select'
import * as React from 'react'
import Badge from '@/app/components/base/badge'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
@ -26,28 +33,30 @@ const TypeSelector: FC<Props> = ({
value,
onSelect,
items,
popupClassName,
popupInnerClassName,
readonly,
}) => {
const selectedItem = value ? items.find(item => `${item.value}` === `${value}`) : undefined
const selectedItem = value ? items.find(item => item.value === value) : undefined
return (
<Select
value={selectedItem ? `${selectedItem.value}` : null}
value={selectedItem?.value}
readOnly={readonly}
onValueChange={(nextValue) => {
if (!nextValue)
return
const nextItem = items.find(item => `${item.value}` === nextValue)
if (nextItem)
onSelect(nextItem)
const selected = items.find(item => item.value === nextValue)
if (selected)
onSelect(selected)
}}
>
<SelectTrigger className="h-9 rounded-lg px-2 text-sm" title={selectedItem?.name}>
<div className="flex w-full items-center justify-between gap-2">
<div className="flex min-w-0 items-center">
<SelectTrigger
className={cn(
'h-9 rounded-lg px-2 text-sm',
readonly ? 'cursor-not-allowed' : 'cursor-pointer',
)}
title={selectedItem?.name}
>
<div className="flex min-w-0 items-center justify-between">
<div className="flex items-center">
<InputVarTypeIcon type={selectedItem?.value as InputVarType} className="size-4 shrink-0 text-text-secondary" />
<span
className={cn(
@ -58,32 +67,31 @@ const TypeSelector: FC<Props> = ({
{selectedItem?.name}
</span>
</div>
<div className="flex shrink-0 items-center">
<div className="ml-2 flex shrink-0 items-center space-x-1">
<Badge uppercase={false}>{inputVarTypeToVarType(selectedItem?.value as InputVarType)}</Badge>
</div>
</div>
</SelectTrigger>
<SelectContent
placement="bottom-start"
sideOffset={4}
className={popupClassName}
popupClassName={cn('w-(--anchor-width) text-base sm:text-sm', popupInnerClassName)}
listClassName="p-1"
popupClassName={cn('w-[432px] rounded-md px-1 py-1 text-base sm:text-sm', popupInnerClassName)}
listClassName="max-h-80 p-0"
>
{items.map((item: Item) => (
<SelectItem
key={item.value}
value={`${item.value}`}
className="h-9 justify-between px-2"
value={item.value}
className="h-9 justify-between px-2 text-text-secondary"
title={item.name}
>
<div className="flex w-full items-center justify-between gap-2">
<div className="flex min-w-0 items-center space-x-2">
<InputVarTypeIcon type={item.value} className="size-4 shrink-0 text-text-secondary" />
<SelectItemText title={item.name} className="mr-0 px-0">{item.name}</SelectItemText>
</div>
<Badge uppercase={false}>{inputVarTypeToVarType(item.value)}</Badge>
</div>
<SelectItemText
className="flex items-center space-x-2 px-0"
>
<InputVarTypeIcon type={item.value} className="size-4 shrink-0 text-text-secondary" />
<span title={item.name}>{item.name}</span>
</SelectItemText>
<Badge uppercase={false}>{inputVarTypeToVarType(item.value)}</Badge>
<SelectItemIndicator />
</SelectItem>
))}
</SelectContent>

View File

@ -1,15 +1,16 @@
'use client'
import type { FC } from 'react'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from '@langgenius/dify-ui/dropdown-menu'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import OperationBtn from '@/app/components/app/configuration/base/operation-btn'
import { ApiConnection } from '@/app/components/base/icons/src/vender/solid/development'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
import { InputVarType } from '@/app/components/workflow/types'
@ -27,13 +28,14 @@ type ItemProps = {
const SelectItem: FC<ItemProps> = ({ text, type, value, Icon, onClick }) => {
return (
<div
className="flex h-8 cursor-pointer items-center rounded-lg px-3 hover:bg-state-base-hover"
<DropdownMenuItem
closeOnClick
className="h-8 rounded-lg px-3 text-text-primary"
onClick={() => onClick(value)}
>
{Icon ? <Icon className="h-4 w-4 text-text-secondary" /> : <InputVarTypeIcon type={type!} className="h-4 w-4 text-text-secondary" />}
<div className="ml-2 truncate text-xs text-text-primary">{text}</div>
</div>
</DropdownMenuItem>
)
}
@ -41,40 +43,36 @@ const SelectVarType: FC<Props> = ({
onChange,
}) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const handleChange = (value: string) => {
onChange(value)
setOpen(false)
}
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement="bottom-end"
offset={{
mainAxis: 8,
crossAxis: -2,
}}
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<DropdownMenu>
<DropdownMenuTrigger
nativeButton={false}
render={<div className="block" />}
>
<OperationBtn type="add" />
</PortalToFollowElemTrigger>
<PortalToFollowElemContent style={{ zIndex: 1000 }}>
<div className="min-w-[192px] rounded-lg border border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-xs">
<div className="p-1">
<SelectItem type={InputVarType.textInput} value="string" text={t('variableConfig.string', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.paragraph} value="paragraph" text={t('variableConfig.paragraph', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.select} value="select" text={t('variableConfig.select', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.number} value="number" text={t('variableConfig.number', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.checkbox} value="checkbox" text={t('variableConfig.checkbox', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
<div className="h-px border-t border-components-panel-border"></div>
<div className="p-1">
<SelectItem Icon={ApiConnection} value="api" text={t('variableConfig.apiBasedVar', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
</DropdownMenuTrigger>
<DropdownMenuContent
placement="bottom-end"
sideOffset={8}
alignOffset={-2}
popupClassName="min-w-[192px] rounded-lg border bg-components-panel-bg-blur p-0 backdrop-blur-xs"
>
<div className="p-1">
<SelectItem type={InputVarType.textInput} value="string" text={t('variableConfig.string', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.paragraph} value="paragraph" text={t('variableConfig.paragraph', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.select} value="select" text={t('variableConfig.select', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.number} value="number" text={t('variableConfig.number', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.checkbox} value="checkbox" text={t('variableConfig.checkbox', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<DropdownMenuSeparator className="my-0" />
<div className="p-1">
<SelectItem Icon={ApiConnection} value="api" text={t('variableConfig.apiBasedVar', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
</DropdownMenuContent>
</DropdownMenu>
)
}
export default React.memo(SelectVarType)

View File

@ -841,11 +841,11 @@ describe('AssistantTypePicker', () => {
it('should have proper ARIA state for dropdown', async () => {
// Arrange
const user = userEvent.setup()
const { container } = renderComponent()
renderComponent()
// Act - Check initial state
const portalContainer = container.querySelector('[data-state]')
expect(portalContainer)!.toHaveAttribute('data-state', 'closed')
const triggerButton = screen.getByRole('button', { name: /chatAssistant\.name/i })
expect(triggerButton).toHaveAttribute('aria-expanded', 'false')
// Open dropdown
const trigger = screen.getByText(/chatAssistant.name/i)
@ -853,23 +853,22 @@ describe('AssistantTypePicker', () => {
// Assert - State should change to open
await waitFor(() => {
const openPortal = container.querySelector('[data-state="open"]')
expect(openPortal)!.toBeInTheDocument()
expect(triggerButton).toHaveAttribute('aria-expanded', 'true')
})
})
it('should have proper data-state attribute', () => {
// Arrange & Act
const { container } = renderComponent()
renderComponent()
// Assert - Portal should have data-state for accessibility
const portalContainer = container.querySelector('[data-state]')
expect(portalContainer)!.toBeInTheDocument()
expect(portalContainer)!.toHaveAttribute('data-state')
// Assert - Trigger should expose expanded state for accessibility
const triggerButton = screen.getByRole('button', { name: /chatAssistant\.name/i })
expect(triggerButton).toBeInTheDocument()
expect(triggerButton).toHaveAttribute('aria-expanded')
// Should start in closed state
// Should start in closed state
expect(portalContainer)!.toHaveAttribute('data-state', 'closed')
expect(triggerButton).toHaveAttribute('aria-expanded', 'false')
})
it('should maintain accessible structure for screen readers', () => {

View File

@ -2,6 +2,11 @@
import type { FC } from 'react'
import type { AgentConfig } from '@/models/debug'
import { cn } from '@langgenius/dify-ui/cn'
import {
Popover,
PopoverContent,
PopoverTrigger,
} from '@langgenius/dify-ui/popover'
import { RiArrowDownSLine } from '@remixicon/react'
import * as React from 'react'
import { useState } from 'react'
@ -10,11 +15,6 @@ import { ArrowUpRight } from '@/app/components/base/icons/src/vender/line/arrows
import { Settings04 } from '@/app/components/base/icons/src/vender/line/general'
import { CuteRobot } from '@/app/components/base/icons/src/vender/solid/communication'
import { BubbleText } from '@/app/components/base/icons/src/vender/solid/education'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import Radio from '@/app/components/base/radio/ui'
import AgentSetting from '../agent/agent-setting'
@ -107,47 +107,48 @@ const AssistantTypePicker: FC<Props> = ({
)
return (
<>
<PortalToFollowElem
<Popover
open={open}
onOpenChange={setOpen}
placement="bottom-end"
offset={{
mainAxis: 8,
crossAxis: -2,
}}
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className={cn(open && 'bg-gray-50', 'flex h-8 cursor-pointer items-center space-x-1 rounded-lg border border-black/5 px-3 text-indigo-600 select-none')}>
{isAgent ? <BubbleText className="h-3 w-3" /> : <CuteRobot className="h-3 w-3" />}
<div className="text-xs font-medium">{t(`assistantType.${isAgent ? 'agentAssistant' : 'chatAssistant'}.name`, { ns: 'appDebug' })}</div>
<RiArrowDownSLine className="h-3 w-3" />
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent style={{ zIndex: 1000 }}>
<div className="relative left-0.5 w-[480px] rounded-xl border border-black/8 bg-white p-6 shadow-lg">
<div className="mb-2 text-sm leading-5 font-semibold text-gray-900">{t('assistantType.name', { ns: 'appDebug' })}</div>
<SelectItem
Icon={BubbleText}
value="chat"
disabled={disabled}
text={t('assistantType.chatAssistant.name', { ns: 'appDebug' })}
description={t('assistantType.chatAssistant.description', { ns: 'appDebug' })}
isChecked={!isAgent}
onClick={handleChange}
/>
<SelectItem
Icon={CuteRobot}
value="agent"
disabled={disabled}
text={t('assistantType.agentAssistant.name', { ns: 'appDebug' })}
description={t('assistantType.agentAssistant.description', { ns: 'appDebug' })}
isChecked={isAgent}
onClick={handleChange}
/>
{!disabled && agentConfigUI}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<PopoverTrigger
nativeButton={false}
render={(
<div className={cn(open && 'bg-gray-50', 'flex h-8 cursor-pointer items-center space-x-1 rounded-lg border border-black/5 px-3 text-indigo-600 select-none')} />
)}
>
{isAgent ? <BubbleText className="h-3 w-3" /> : <CuteRobot className="h-3 w-3" />}
<div className="text-xs font-medium">{t(`assistantType.${isAgent ? 'agentAssistant' : 'chatAssistant'}.name`, { ns: 'appDebug' })}</div>
<RiArrowDownSLine className="h-3 w-3" />
</PopoverTrigger>
<PopoverContent
placement="bottom-end"
sideOffset={8}
alignOffset={-2}
popupClassName="relative left-0.5 w-[480px] rounded-xl border border-black/8 bg-white p-6 shadow-lg"
>
<div className="mb-2 text-sm leading-5 font-semibold text-gray-900">{t('assistantType.name', { ns: 'appDebug' })}</div>
<SelectItem
Icon={BubbleText}
value="chat"
disabled={disabled}
text={t('assistantType.chatAssistant.name', { ns: 'appDebug' })}
description={t('assistantType.chatAssistant.description', { ns: 'appDebug' })}
isChecked={!isAgent}
onClick={handleChange}
/>
<SelectItem
Icon={CuteRobot}
value="agent"
disabled={disabled}
text={t('assistantType.agentAssistant.name', { ns: 'appDebug' })}
description={t('assistantType.agentAssistant.description', { ns: 'appDebug' })}
isChecked={isAgent}
onClick={handleChange}
/>
{!disabled && agentConfigUI}
</PopoverContent>
</Popover>
{isShowAgentSetting && (
<AgentSetting
isFunctionCall={isFunctionCall}

View File

@ -1,4 +1,4 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import VersionSelector from '../version-selector'
vi.mock('react-i18next', () => ({
@ -25,7 +25,7 @@ describe('VersionSelector', () => {
expect(onChange).not.toHaveBeenCalled()
})
it('should open the selector and switch versions when multiple versions exist', () => {
it('should open the selector and switch versions when multiple versions exist', async () => {
const onChange = vi.fn()
render(
@ -44,6 +44,8 @@ describe('VersionSelector', () => {
fireEvent.click(screen.getByText('generate.version 1'))
expect(onChange).toHaveBeenCalledWith(0)
expect(screen.queryByText('generate.versions')).not.toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByText('generate.versions')).not.toBeInTheDocument()
})
})
})

View File

@ -1,10 +1,16 @@
import { cn } from '@langgenius/dify-ui/cn'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuRadioGroup,
DropdownMenuRadioItem,
DropdownMenuTrigger,
} from '@langgenius/dify-ui/dropdown-menu'
import { RiArrowDownSLine, RiCheckLine } from '@remixicon/react'
import { useBoolean } from 'ahooks'
import * as React from 'react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '@/app/components/base/portal-to-follow-elem'
type VersionSelectorProps = {
versionLen: number
@ -16,19 +22,14 @@ const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, on
const { t } = useTranslation()
const [isOpen, {
setFalse: handleOpenFalse,
toggle: handleOpenToggle,
set: handleOpenSet,
}] = useBoolean(false)
const moreThanOneVersion = versionLen > 1
const handleOpen = useCallback((value: boolean) => {
const handleOpen = useCallback((nextOpen: boolean) => {
if (moreThanOneVersion)
handleOpenSet(value)
}, [moreThanOneVersion, handleOpenToggle])
const handleToggle = useCallback(() => {
if (moreThanOneVersion)
handleOpenToggle()
}, [moreThanOneVersion, handleOpenToggle])
handleOpenSet(nextOpen)
}, [moreThanOneVersion, handleOpenSet])
const versions = Array.from({ length: versionLen }, (_, index) => ({
label: `${t('generate.version', { ns: 'appDebug' })} ${index + 1}${index === versionLen - 1 ? ` · ${t('generate.latest', { ns: 'appDebug' })}` : ''}`,
@ -38,67 +39,59 @@ const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, on
const isLatest = value === versionLen - 1
return (
<PortalToFollowElem
placement="bottom-start"
offset={{
mainAxis: 4,
crossAxis: -12,
}}
<DropdownMenu
open={isOpen}
onOpenChange={handleOpen}
>
<PortalToFollowElemTrigger
onClick={handleToggle}
asChild
<DropdownMenuTrigger
nativeButton={false}
render={(
<div className={cn('flex items-center system-xs-medium text-text-tertiary', isOpen && 'text-text-secondary', moreThanOneVersion && 'cursor-pointer')} />
)}
>
<div className={cn('flex items-center system-xs-medium text-text-tertiary', isOpen && 'text-text-secondary', moreThanOneVersion && 'cursor-pointer')}>
<div>
{t('generate.version', { ns: 'appDebug' })}
{' '}
{value + 1}
{isLatest && ` · ${t('generate.latest', { ns: 'appDebug' })}`}
</div>
{moreThanOneVersion && <RiArrowDownSLine className="size-3" />}
<div>
{t('generate.version', { ns: 'appDebug' })}
{' '}
{value + 1}
{isLatest && ` · ${t('generate.latest', { ns: 'appDebug' })}`}
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={cn(
'z-99',
)}
{moreThanOneVersion && <RiArrowDownSLine className="size-3" />}
</DropdownMenuTrigger>
<DropdownMenuContent
placement="bottom-start"
sideOffset={4}
alignOffset={-12}
popupClassName="w-[208px] rounded-xl border-[0.5px] bg-components-panel-bg-blur p-1"
>
<div
className={cn(
'w-[208px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg',
)}
<div className={cn('flex h-[22px] items-center px-3 pl-3 system-xs-medium-uppercase text-text-tertiary')}>
{t('generate.versions', { ns: 'appDebug' })}
</div>
<DropdownMenuRadioGroup
value={value}
onValueChange={(nextValue) => {
onChange(nextValue)
handleOpenFalse()
}}
>
<div className={cn('flex h-[22px] items-center px-3 pl-3 system-xs-medium-uppercase text-text-tertiary')}>
{t('generate.versions', { ns: 'appDebug' })}
</div>
{
versions.map(option => (
<div
key={option.value}
className={cn(
'flex h-7 cursor-pointer items-center rounded-lg px-2 system-sm-medium text-text-secondary hover:bg-state-base-hover',
)}
title={option.label}
onClick={() => {
onChange(option.value)
handleOpenFalse()
}}
>
<div className="mr-1 grow truncate px-1 pl-1">
{option.label}
</div>
{
value === option.value && <RiCheckLine className="h-4 w-4 shrink-0 text-text-accent" />
}
{versions.map(option => (
<DropdownMenuRadioItem
key={option.value}
value={option.value}
closeOnClick
className="h-7 rounded-lg px-2 system-sm-medium text-text-secondary"
title={option.label}
>
<div className="mr-1 grow truncate px-1 pl-1">
{option.label}
</div>
))
}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
{
value === option.value && <RiCheckLine className="h-4 w-4 shrink-0 text-text-accent" />
}
</DropdownMenuRadioItem>
))}
</DropdownMenuRadioGroup>
</DropdownMenuContent>
</DropdownMenu>
)
}

View File

@ -1,7 +1,7 @@
import type * as React from 'react'
import type { Props } from '../var-picker'
import { render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
import ContextVar from '../index'
// Mock external dependencies only
@ -76,57 +76,6 @@ vi.mock('@langgenius/dify-ui/popover', async () => {
}
})
type PortalToFollowElemProps = {
children: React.ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type PortalToFollowElemTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode, asChild?: boolean }
type PortalToFollowElemContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
vi.mock('@/app/components/base/portal-to-follow-elem', () => {
const PortalContext = React.createContext({ open: false })
const PortalToFollowElem = ({ children, open }: PortalToFollowElemProps) => {
return (
<PortalContext.Provider value={{ open: !!open }}>
<div data-testid="portal">{children}</div>
</PortalContext.Provider>
)
}
const PortalToFollowElemContent = ({ children, ...props }: PortalToFollowElemContentProps) => {
const { open } = React.useContext(PortalContext)
if (!open)
return null
return (
<div data-testid="portal-content" {...props}>
{children}
</div>
)
}
const PortalToFollowElemTrigger = ({ children, asChild, ...props }: PortalToFollowElemTriggerProps) => {
if (asChild && React.isValidElement(children)) {
return React.cloneElement(children, {
...props,
'data-testid': 'portal-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
return (
<div data-testid="portal-trigger" {...props}>
{children}
</div>
)
}
return {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
}
})
describe('ContextVar', () => {
const mockOptions: Props['options'] = [
{ name: 'Variable 1', value: 'var1', type: 'string' },

View File

@ -10,41 +10,41 @@ vi.mock('@/next/navigation', () => ({
usePathname: () => '/test',
}))
type PortalToFollowElemProps = {
type PopoverProps = {
children: React.ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type PortalToFollowElemTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode, asChild?: boolean }
type PortalToFollowElemContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
type PopoverTriggerProps = React.HTMLAttributes<HTMLElement> & { children?: React.ReactNode, asChild?: boolean }
type PopoverContentProps = React.HTMLAttributes<HTMLDivElement> & { children?: React.ReactNode }
vi.mock('@langgenius/dify-ui/popover', () => {
const PortalContext = React.createContext({
const PopoverContext = React.createContext({
open: false,
onOpenChange: undefined as ((open: boolean) => void) | undefined,
})
const Popover = ({ children, open, onOpenChange }: PortalToFollowElemProps) => {
const Popover = ({ children, open, onOpenChange }: PopoverProps) => {
return (
<PortalContext.Provider value={{ open: !!open, onOpenChange }}>
<div data-testid="portal">{children}</div>
</PortalContext.Provider>
<PopoverContext.Provider value={{ open: !!open, onOpenChange }}>
<div data-testid="popover">{children}</div>
</PopoverContext.Provider>
)
}
const PopoverContent = ({ children, ...props }: PortalToFollowElemContentProps) => {
const { open } = React.useContext(PortalContext)
const PopoverContent = ({ children, ...props }: PopoverContentProps) => {
const { open } = React.useContext(PopoverContext)
if (!open)
return null
return (
<div data-testid="portal-content" {...props}>
<div data-testid="popover-content" {...props}>
{children}
</div>
)
}
const PopoverTrigger = ({ children, asChild, render, ...props }: PortalToFollowElemTriggerProps & { render?: React.ReactNode }) => {
const { open, onOpenChange } = React.useContext(PortalContext)
const PopoverTrigger = ({ children, asChild, render, ...props }: PopoverTriggerProps & { render?: React.ReactNode }) => {
const { open, onOpenChange } = React.useContext(PopoverContext)
const content = render ?? children
const handleClick = (e: React.MouseEvent<HTMLElement>) => {
props.onClick?.(e)
@ -56,7 +56,7 @@ vi.mock('@langgenius/dify-ui/popover', () => {
return React.cloneElement(content, {
...props,
'onClick': handleClick,
'data-testid': 'portal-trigger',
'data-testid': 'popover-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
@ -64,11 +64,11 @@ vi.mock('@langgenius/dify-ui/popover', () => {
return React.cloneElement(children, {
...props,
'onClick': handleClick,
'data-testid': 'portal-trigger',
'data-testid': 'popover-trigger',
} as React.HTMLAttributes<HTMLElement>)
}
return (
<div data-testid="portal-trigger" {...props} onClick={handleClick}>
<div data-testid="popover-trigger" {...props} onClick={handleClick}>
{content}
</div>
)
@ -109,7 +109,7 @@ describe('VarPicker', () => {
// Assert
// Assert
expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-trigger'))!.toBeInTheDocument()
expect(screen.getByText('var1'))!.toBeInTheDocument()
})
@ -201,7 +201,7 @@ describe('VarPicker', () => {
// Assert - Trigger should be present
// Assert - Trigger should be present
expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-trigger'))!.toBeInTheDocument()
})
})
@ -234,7 +234,7 @@ describe('VarPicker', () => {
// Assert
// Assert
expect(screen.getByTestId('portal-trigger'))!.toHaveClass('custom-trigger-class')
expect(screen.getByTestId('popover-trigger'))!.toHaveClass('custom-trigger-class')
})
it('should display selected value with proper formatting', () => {
@ -268,11 +268,11 @@ describe('VarPicker', () => {
// Act
render(<VarPicker {...props} />)
await user.click(screen.getByTestId('portal-trigger'))
await user.click(screen.getByTestId('popover-trigger'))
// Assert
// Assert
expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-content'))!.toBeInTheDocument()
})
it('should call onChange and close dropdown when selecting an option', async () => {
@ -285,8 +285,8 @@ describe('VarPicker', () => {
render(<VarPicker {...props} />)
// Open dropdown
await user.click(screen.getByTestId('portal-trigger'))
expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
await user.click(screen.getByTestId('popover-trigger'))
expect(screen.getByTestId('popover-content'))!.toBeInTheDocument()
// Select a different option
const options = screen.getAllByText('var2')
@ -295,7 +295,7 @@ describe('VarPicker', () => {
// Assert
expect(onChange).toHaveBeenCalledWith('var2')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument()
})
it('should toggle dropdown when clicking trigger button multiple times', async () => {
@ -306,15 +306,15 @@ describe('VarPicker', () => {
// Act
render(<VarPicker {...props} />)
const trigger = screen.getByTestId('portal-trigger')
const trigger = screen.getByTestId('popover-trigger')
// Open dropdown
await user.click(trigger)
expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-content'))!.toBeInTheDocument()
// Close dropdown
await user.click(trigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument()
})
})
@ -359,7 +359,7 @@ describe('VarPicker', () => {
// Assert
// Assert
// Assert
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument()
})
it('should toggle dropdown state on trigger click', async () => {
@ -370,16 +370,16 @@ describe('VarPicker', () => {
// Act
render(<VarPicker {...props} />)
const trigger = screen.getByTestId('portal-trigger')
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
const trigger = screen.getByTestId('popover-trigger')
expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument()
// Open dropdown
await user.click(trigger)
expect(screen.getByTestId('portal-content'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-content'))!.toBeInTheDocument()
// Close dropdown
await user.click(trigger)
expect(screen.queryByTestId('portal-content')).not.toBeInTheDocument()
expect(screen.queryByTestId('popover-content')).not.toBeInTheDocument()
})
it('should preserve selected value when dropdown is closed without selection', async () => {
@ -391,7 +391,7 @@ describe('VarPicker', () => {
render(<VarPicker {...props} />)
// Open and close dropdown without selecting anything
const trigger = screen.getByTestId('portal-trigger')
const trigger = screen.getByTestId('popover-trigger')
await user.click(trigger)
await user.click(trigger)
@ -416,7 +416,7 @@ describe('VarPicker', () => {
// Assert
// Assert
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder'))!.toBeInTheDocument()
expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-trigger'))!.toBeInTheDocument()
})
it('should handle empty options array', () => {
@ -432,7 +432,7 @@ describe('VarPicker', () => {
// Assert
// Assert
expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-trigger'))!.toBeInTheDocument()
expect(screen.getByText('appDebug.feature.dataSet.queryVariable.choosePlaceholder'))!.toBeInTheDocument()
})
@ -485,7 +485,7 @@ describe('VarPicker', () => {
// Assert
// Assert
expect(screen.getByText('longVar'))!.toBeInTheDocument()
expect(screen.getByTestId('portal-trigger'))!.toBeInTheDocument()
expect(screen.getByTestId('popover-trigger'))!.toBeInTheDocument()
})
})
})

View File

@ -486,6 +486,15 @@ describe('AppCard', () => {
expect(screen.getByTestId('dropdown-menu')).toHaveAttribute('data-modal', 'false')
})
it('should reveal operations trigger when card receives keyboard focus', () => {
render(<AppCard app={mockApp} />)
const operationsTriggerWrapper = screen.getByTestId('dropdown-menu-trigger').closest('.absolute')
expect(operationsTriggerWrapper).toHaveClass('group-focus-within:pointer-events-auto')
expect(operationsTriggerWrapper).toHaveClass('group-focus-within:opacity-100')
expect(screen.getByTestId('dropdown-menu-trigger')).toHaveClass('focus-visible:ring-1')
})
it('should show edit option when dropdown menu is opened', async () => {
render(<AppCard app={mockApp} />)

View File

@ -425,7 +425,7 @@ const AppCard = ({ app, onlineUsers = [], onRefresh, onOpenTagManagement = () =>
e.preventDefault()
getRedirection(isCurrentWorkspaceEditor, app, push)
}}
className="group relative col-span-1 inline-flex h-[160px] cursor-pointer flex-col rounded-xl border border-solid border-components-card-border bg-components-card-bg shadow-sm transition-all duration-200 ease-in-out hover:shadow-lg"
className="group relative col-span-1 inline-flex h-[160px] cursor-pointer flex-col rounded-xl border border-solid border-components-card-border bg-components-card-bg shadow-sm transition-shadow duration-200 ease-in-out hover:shadow-lg"
>
<div className="flex h-[66px] shrink-0 grow-0 items-center gap-3 px-[14px] pt-[14px] pb-3">
<div className="relative shrink-0">
@ -524,7 +524,7 @@ const AppCard = ({ app, onlineUsers = [], onRefresh, onOpenTagManagement = () =>
'absolute top-1/2 right-[6px] flex -translate-y-1/2 items-center transition-opacity',
isOperationsMenuOpen
? 'pointer-events-auto opacity-100'
: 'pointer-events-none opacity-0 group-hover:pointer-events-auto group-hover:opacity-100',
: 'pointer-events-none opacity-0 group-focus-within:pointer-events-auto group-focus-within:opacity-100 group-hover:pointer-events-auto group-hover:opacity-100',
)}
>
<div className="mx-1 h-[14px] w-px shrink-0 bg-divider-regular" />
@ -533,7 +533,7 @@ const AppCard = ({ app, onlineUsers = [], onRefresh, onOpenTagManagement = () =>
aria-label={t('operation.more', { ns: 'common' })}
className={cn(
isOperationsMenuOpen ? 'bg-state-base-hover shadow-none' : 'bg-transparent',
'flex h-8 w-8 items-center justify-center rounded-md border-none p-2 hover:bg-state-base-hover',
'flex h-8 w-8 items-center justify-center rounded-md border-none p-2 hover:bg-state-base-hover focus-visible:bg-state-base-hover focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:ring-inset',
)}
onClick={(e) => {
e.stopPropagation()

View File

@ -40,43 +40,24 @@ vi.mock('../../embedded-chatbot/theme/theme-context', () => ({
})),
}))
// Mock PortalToFollowElem using React Context
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const React = await import('react')
const MockContext = React.createContext(false)
vi.mock('@langgenius/dify-ui/dropdown-menu', () => import('@/__mocks__/base-ui-dropdown-menu'))
vi.mock('@langgenius/dify-ui/tooltip', () => import('@/__mocks__/base-ui-tooltip'))
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => {
return (
<MockContext.Provider value={open}>
<div data-open={open}>{children}</div>
</MockContext.Provider>
)
},
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = React.useContext(MockContext)
if (!open)
return null
return <div>{children}</div>
},
PortalToFollowElemTrigger: ({ children, onClick, ...props }: { children: React.ReactNode, onClick: () => void } & React.HTMLAttributes<HTMLDivElement>) => (
<div onClick={onClick} {...props}>{children}</div>
),
}
})
// Mock Modal to avoid Headless UI issues in tests
vi.mock('@/app/components/base/modal', () => ({
default: ({ children, isShow, title }: { children: React.ReactNode, isShow: boolean, title: React.ReactNode }) => {
if (!isShow)
// Mock Dialog to avoid Base UI focus/portal behavior in tests
vi.mock('@langgenius/dify-ui/dialog', () => ({
Dialog: ({ children, open }: { children: React.ReactNode, open?: boolean }) => {
if (!open)
return null
return (
<div role="dialog" data-testid="modal">
{!!title && <div>{title}</div>}
<div data-testid="modal">
{children}
</div>
)
},
DialogContent: ({ children }: { children: React.ReactNode }) => (
<div role="dialog" data-testid="modal-content">{children}</div>
),
DialogTitle: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))
// Sidebar mock removed to use real component

View File

@ -16,43 +16,24 @@ vi.mock('@/app/components/base/chat/chat-with-history/inputs-form/content', () =
default: () => <div data-testid="inputs-form-content">InputsFormContent</div>,
}))
// Mock PortalToFollowElem using React Context
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const React = await import('react')
const MockContext = React.createContext(false)
vi.mock('@langgenius/dify-ui/dropdown-menu', () => import('@/__mocks__/base-ui-dropdown-menu'))
vi.mock('@langgenius/dify-ui/tooltip', () => import('@/__mocks__/base-ui-tooltip'))
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => {
return (
<MockContext.Provider value={open}>
<div data-open={open}>{children}</div>
</MockContext.Provider>
)
},
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = React.useContext(MockContext)
if (!open)
return null
return <div>{children}</div>
},
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick: () => void }) => (
<div onClick={onClick}>{children}</div>
),
}
})
// Mock Modal to avoid Headless UI issues in tests
vi.mock('@/app/components/base/modal', () => ({
default: ({ children, isShow, title }: { children: React.ReactNode, isShow: boolean, title: React.ReactNode }) => {
if (!isShow)
// Mock Dialog to avoid Base UI focus/portal behavior in tests
vi.mock('@langgenius/dify-ui/dialog', () => ({
Dialog: ({ children, open }: { children: React.ReactNode, open?: boolean }) => {
if (!open)
return null
return (
<div data-testid="modal">
{!!title && <div>{title}</div>}
{children}
</div>
)
},
DialogContent: ({ children }: { children: React.ReactNode }) => (
<div role="dialog" data-testid="modal-content">{children}</div>
),
DialogTitle: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))
const mockAppData: AppData = {

View File

@ -490,7 +490,7 @@ describe('Sidebar Index', () => {
render(<Sidebar />)
await user.click(screen.getByTestId('rename-1'))
expect(screen.getByTestId('modal')).toBeInTheDocument()
expect(screen.getByText('common.chat.renameConversation')).toBeInTheDocument()
})
it('should pass correct props to rename modal', async () => {
@ -499,7 +499,9 @@ describe('Sidebar Index', () => {
await user.click(screen.getByTestId('rename-1'))
// The modal should have title and save/cancel
expect(screen.getByTestId('modal')).toBeInTheDocument()
expect(screen.getByText('common.chat.renameConversation')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.save' })).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'common.operation.cancel' })).toBeInTheDocument()
})
it('should call handleRenameConversation with new name', async () => {
@ -531,13 +533,13 @@ describe('Sidebar Index', () => {
render(<Sidebar />)
await user.click(screen.getByTestId('rename-1'))
expect(screen.getByTestId('modal')).toBeInTheDocument()
expect(screen.getByText('common.chat.renameConversation')).toBeInTheDocument()
const cancelButton = screen.getByText('common.operation.cancel')
await user.click(cancelButton)
await waitFor(() => {
expect(screen.queryByTestId('modal')).not.toBeInTheDocument()
expect(screen.queryByText('common.chat.renameConversation')).not.toBeInTheDocument()
})
})
@ -882,8 +884,7 @@ describe('RenameModal', () => {
/>,
)
expect(screen.getByTestId('modal')).toBeInTheDocument()
expect(screen.getByTestId('modal-title')).toHaveTextContent('common.chat.renameConversation')
expect(screen.getByText('common.chat.renameConversation')).toBeInTheDocument()
})
it('should handle empty placeholder translation fallback', () => {

View File

@ -1,11 +1,15 @@
'use client'
import type { FC } from 'react'
import { Button } from '@langgenius/dify-ui/button'
import {
Dialog,
DialogContent,
DialogTitle,
} from '@langgenius/dify-ui/dialog'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
type IRenameModalProps = {
isShow: boolean
@ -27,24 +31,28 @@ const RenameModal: FC<IRenameModalProps> = ({
const conversationNamePlaceholder = t('chat.conversationNamePlaceholder', { ns: 'common' }) || ''
return (
<Modal
title={t('chat.renameConversation', { ns: 'common' })}
isShow={isShow}
onClose={onClose}
<Dialog
open={isShow}
onOpenChange={open => !open && onClose()}
>
<div className="mt-6 text-sm leading-[21px] font-medium text-text-primary">{t('chat.conversationName', { ns: 'common' })}</div>
<Input
className="mt-2 h-10 w-full"
value={tempName}
onChange={e => setTempName(e.target.value)}
placeholder={conversationNamePlaceholder}
/>
<DialogContent>
<DialogTitle className="title-2xl-semi-bold text-text-primary">
{t('chat.renameConversation', { ns: 'common' })}
</DialogTitle>
<div className="mt-6 text-sm leading-[21px] font-medium text-text-primary">{t('chat.conversationName', { ns: 'common' })}</div>
<Input
className="mt-2 h-10 w-full"
value={tempName}
onChange={e => setTempName(e.target.value)}
placeholder={conversationNamePlaceholder}
/>
<div className="mt-10 flex justify-end">
<Button className="mr-2 shrink-0" onClick={onClose}>{t('operation.cancel', { ns: 'common' })}</Button>
<Button variant="primary" className="shrink-0" onClick={() => onSave(tempName)} loading={saveLoading}>{t('operation.save', { ns: 'common' })}</Button>
</div>
</Modal>
<div className="mt-10 flex justify-end">
<Button className="mr-2 shrink-0" onClick={onClose}>{t('operation.cancel', { ns: 'common' })}</Button>
<Button variant="primary" className="shrink-0" onClick={() => onSave(tempName)} loading={saveLoading}>{t('operation.save', { ns: 'common' })}</Button>
</div>
</DialogContent>
</Dialog>
)
}
export default React.memo(RenameModal)

View File

@ -441,9 +441,8 @@ describe('Operation', () => {
renderOperation()
const thumbDown = screen.getByTestId('operation-bar').querySelector('.i-ri-thumb-down-line')!.closest('button')!
await user.click(thumbDown)
// Check if modal title/labels fallback works
// Check if modal title/labels fallback works
expect(screen.getByRole('tooltip'))!.toBeInTheDocument()
expect(screen.getByRole('dialog', { name: 'Provide Feedback' }))!.toBeInTheDocument()
expect(screen.getByLabelText('Feedback Content'))!.toBeInTheDocument()
mockT.mockImplementation(key => key)
})
})

View File

@ -1,13 +1,17 @@
import type { FC } from 'react'
import type { ReactElement, ReactNode } from 'react'
import type {
ChatItem,
Feedback,
} from '../../types'
import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { Dialog, DialogCloseButton, DialogContent, DialogDescription, DialogTitle } from '@langgenius/dify-ui/dialog'
import { toast } from '@langgenius/dify-ui/toast'
import { Tooltip, TooltipContent, TooltipTrigger } from '@langgenius/dify-ui/tooltip'
import copy from 'copy-to-clipboard'
import {
memo,
useId,
useMemo,
useState,
} from 'react'
@ -16,10 +20,8 @@ import EditReplyModal from '@/app/components/app/annotation/edit-annotation-moda
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
import Log from '@/app/components/base/chat/chat/log'
import AnnotationCtrlButton from '@/app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-button'
import Modal from '@/app/components/base/modal/modal'
import NewAudioButton from '@/app/components/base/new-audio-button'
import Textarea from '@/app/components/base/textarea'
import Tooltip from '@/app/components/base/tooltip'
import { useChatContext } from '../context'
type OperationProps = {
@ -33,7 +35,25 @@ type OperationProps = {
noChatInput?: boolean
}
const Operation: FC<OperationProps> = ({
type FeedbackTooltipProps = {
content: ReactNode
children: ReactElement
}
const feedbackTooltipClassName = 'max-w-[260px]'
const FeedbackTooltip = ({ content, children }: FeedbackTooltipProps) => {
return (
<Tooltip>
<TooltipTrigger render={children} />
<TooltipContent className={feedbackTooltipClassName}>
{content}
</TooltipContent>
</Tooltip>
)
}
function Operation({
item,
question,
index,
@ -42,7 +62,7 @@ const Operation: FC<OperationProps> = ({
contentWidth,
hasWorkflowProcess,
noChatInput,
}) => {
}: OperationProps) {
const { t } = useTranslation()
const {
config,
@ -68,8 +88,8 @@ const Operation: FC<OperationProps> = ({
const [userLocalFeedback, setUserLocalFeedback] = useState(feedback)
const [adminLocalFeedback, setAdminLocalFeedback] = useState(adminFeedback)
const [feedbackTarget, setFeedbackTarget] = useState<'user' | 'admin'>('user')
const feedbackTextareaId = useId()
// Separate feedback types for display
const userFeedback = feedback
const content = useMemo(() => {
@ -89,7 +109,11 @@ const Operation: FC<OperationProps> = ({
const userFeedbackLabel = t('table.header.userRate', { ns: 'appLog' }) || 'User feedback'
const adminFeedbackLabel = t('table.header.adminRate', { ns: 'appLog' }) || 'Admin feedback'
const feedbackTooltipClassName = 'max-w-[260px]'
const likeLabel = t('detail.operation.like', { ns: 'appLog' }) || 'Like'
const dislikeLabel = t('detail.operation.dislike', { ns: 'appLog' }) || 'Dislike'
const removeFeedbackLabel = t('operation.remove', { ns: 'common' }) || 'Remove'
const copyLabel = t('operation.copy', { ns: 'common' }) || 'Copy'
const regenerateLabel = t('operation.regenerate', { ns: 'common' }) || 'Regenerate'
const buildFeedbackTooltip = (feedbackData?: Feedback | null, label = userFeedbackLabel) => {
if (!feedbackData?.rating)
@ -180,33 +204,35 @@ const Operation: FC<OperationProps> = ({
>
{hasUserFeedback
? (
<Tooltip
popupContent={buildFeedbackTooltip(displayUserFeedback, userFeedbackLabel)}
popupClassName={feedbackTooltipClassName}
<FeedbackTooltip
content={buildFeedbackTooltip(displayUserFeedback, userFeedbackLabel)}
>
<ActionButton
aria-label={`${userFeedbackLabel}: ${removeFeedbackLabel}`}
state={displayUserFeedback?.rating === 'like' ? ActionButtonState.Active : ActionButtonState.Destructive}
onClick={() => handleFeedback(null, undefined, 'user')}
>
{displayUserFeedback?.rating === 'like'
? <div className="i-ri-thumb-up-line h-4 w-4" />
: <div className="i-ri-thumb-down-line h-4 w-4" />}
? <span aria-hidden="true" className="i-ri-thumb-up-line h-4 w-4" />
: <span aria-hidden="true" className="i-ri-thumb-down-line h-4 w-4" />}
</ActionButton>
</Tooltip>
</FeedbackTooltip>
)
: (
<>
<ActionButton
aria-label={`${userFeedbackLabel}: ${likeLabel}`}
state={displayUserFeedback?.rating === 'like' ? ActionButtonState.Active : ActionButtonState.Default}
onClick={() => handleLikeClick('user')}
>
<div className="i-ri-thumb-up-line h-4 w-4" />
<span aria-hidden="true" className="i-ri-thumb-up-line h-4 w-4" />
</ActionButton>
<ActionButton
aria-label={`${userFeedbackLabel}: ${dislikeLabel}`}
state={displayUserFeedback?.rating === 'dislike' ? ActionButtonState.Destructive : ActionButtonState.Default}
onClick={() => handleDislikeClick('user')}
>
<div className="i-ri-thumb-down-line h-4 w-4" />
<span aria-hidden="true" className="i-ri-thumb-down-line h-4 w-4" />
</ActionButton>
</>
)}
@ -218,68 +244,65 @@ const Operation: FC<OperationProps> = ({
(hasAdminFeedback || hasUserFeedback) ? 'flex' : 'hidden group-hover:flex',
)}
>
{/* User Feedback Display */}
{displayUserFeedback?.rating && (
<Tooltip
popupContent={buildFeedbackTooltip(displayUserFeedback, userFeedbackLabel)}
popupClassName={feedbackTooltipClassName}
<FeedbackTooltip
content={buildFeedbackTooltip(displayUserFeedback, userFeedbackLabel)}
>
{displayUserFeedback.rating === 'like'
? (
<ActionButton state={ActionButtonState.Active}>
<div className="i-ri-thumb-up-line h-4 w-4" />
<ActionButton aria-label={`${userFeedbackLabel}: ${likeLabel}`} state={ActionButtonState.Active}>
<span aria-hidden="true" className="i-ri-thumb-up-line h-4 w-4" />
</ActionButton>
)
: (
<ActionButton state={ActionButtonState.Destructive}>
<div className="i-ri-thumb-down-line h-4 w-4" />
<ActionButton aria-label={`${userFeedbackLabel}: ${dislikeLabel}`} state={ActionButtonState.Destructive}>
<span aria-hidden="true" className="i-ri-thumb-down-line h-4 w-4" />
</ActionButton>
)}
</Tooltip>
</FeedbackTooltip>
)}
{/* Admin Feedback Controls */}
{displayUserFeedback?.rating && <div className="mx-1 h-3 w-[0.5px] bg-components-actionbar-border" />}
{hasAdminFeedback
? (
<Tooltip
popupContent={buildFeedbackTooltip(adminLocalFeedback, adminFeedbackLabel)}
popupClassName={feedbackTooltipClassName}
<FeedbackTooltip
content={buildFeedbackTooltip(adminLocalFeedback, adminFeedbackLabel)}
>
<ActionButton
aria-label={`${adminFeedbackLabel}: ${removeFeedbackLabel}`}
state={adminLocalFeedback?.rating === 'like' ? ActionButtonState.Active : ActionButtonState.Destructive}
onClick={() => handleFeedback(null, undefined, 'admin')}
>
{adminLocalFeedback?.rating === 'like'
? <div className="i-ri-thumb-up-line h-4 w-4" />
: <div className="i-ri-thumb-down-line h-4 w-4" />}
? <span aria-hidden="true" className="i-ri-thumb-up-line h-4 w-4" />
: <span aria-hidden="true" className="i-ri-thumb-down-line h-4 w-4" />}
</ActionButton>
</Tooltip>
</FeedbackTooltip>
)
: (
<>
<Tooltip
popupContent={buildFeedbackTooltip(adminLocalFeedback, adminFeedbackLabel)}
popupClassName={feedbackTooltipClassName}
<FeedbackTooltip
content={buildFeedbackTooltip(adminLocalFeedback, adminFeedbackLabel)}
>
<ActionButton
aria-label={`${adminFeedbackLabel}: ${likeLabel}`}
state={adminLocalFeedback?.rating === 'like' ? ActionButtonState.Active : ActionButtonState.Default}
onClick={() => handleLikeClick('admin')}
>
<div className="i-ri-thumb-up-line h-4 w-4" />
<span aria-hidden="true" className="i-ri-thumb-up-line h-4 w-4" />
</ActionButton>
</Tooltip>
<Tooltip
popupContent={buildFeedbackTooltip(adminLocalFeedback, adminFeedbackLabel)}
popupClassName={feedbackTooltipClassName}
</FeedbackTooltip>
<FeedbackTooltip
content={buildFeedbackTooltip(adminLocalFeedback, adminFeedbackLabel)}
>
<ActionButton
aria-label={`${adminFeedbackLabel}: ${dislikeLabel}`}
state={adminLocalFeedback?.rating === 'dislike' ? ActionButtonState.Destructive : ActionButtonState.Default}
onClick={() => handleDislikeClick('admin')}
>
<div className="i-ri-thumb-down-line h-4 w-4" />
<span aria-hidden="true" className="i-ri-thumb-down-line h-4 w-4" />
</ActionButton>
</Tooltip>
</FeedbackTooltip>
</>
)}
</div>
@ -300,18 +323,19 @@ const Operation: FC<OperationProps> = ({
)}
{!humanInputFormDataList?.length && (
<ActionButton
aria-label={copyLabel}
onClick={() => {
copy(content)
toast.success(t('actionMsg.copySuccessfully', { ns: 'common' }))
}}
data-testid="copy-btn"
>
<div className="i-ri-clipboard-line h-4 w-4" />
<span aria-hidden="true" className="i-ri-clipboard-line h-4 w-4" />
</ActionButton>
)}
{!noChatInput && (
<ActionButton onClick={() => onRegenerate?.(item)} data-testid="regenerate-btn">
<div className="i-ri-reset-left-line h-4 w-4" />
<ActionButton aria-label={regenerateLabel} onClick={() => onRegenerate?.(item)} data-testid="regenerate-btn">
<span aria-hidden="true" className="i-ri-reset-left-line h-4 w-4" />
</ActionButton>
)}
{config?.supportAnnotation && config.annotation_reply?.enabled && !humanInputFormDataList?.length && (
@ -342,30 +366,56 @@ const Operation: FC<OperationProps> = ({
onRemove={() => onAnnotationRemoved?.(index)}
/>
{isShowFeedbackModal && (
<Modal
title={t('feedback.title', { ns: 'common' }) || 'Provide Feedback'}
subTitle={t('feedback.subtitle', { ns: 'common' }) || 'Please tell us what went wrong with this response'}
onClose={handleFeedbackCancel}
onConfirm={handleFeedbackSubmit}
onCancel={handleFeedbackCancel}
confirmButtonText={t('operation.submit', { ns: 'common' }) || 'Submit'}
cancelButtonText={t('operation.cancel', { ns: 'common' }) || 'Cancel'}
<Dialog
open
onOpenChange={(open) => {
if (!open)
handleFeedbackCancel()
}}
>
<div className="space-y-3">
<div>
<label className="mb-2 block system-sm-semibold text-text-secondary">
{t('feedback.content', { ns: 'common' }) || 'Feedback Content'}
</label>
<Textarea
value={feedbackContent}
onChange={e => setFeedbackContent(e.target.value)}
placeholder={t('feedback.placeholder', { ns: 'common' }) || 'Please describe what went wrong or how we can improve...'}
rows={4}
className="w-full"
/>
<DialogContent
backdropProps={{ forceRender: true }}
className="p-0"
>
<div className="flex max-h-[80dvh] flex-col">
<div className="relative shrink-0 p-6 pr-14 pb-3">
<DialogTitle className="title-2xl-semi-bold text-text-primary">
{t('feedback.title', { ns: 'common' }) || 'Provide Feedback'}
</DialogTitle>
<DialogDescription className="mt-1 system-xs-regular text-text-tertiary">
{t('feedback.subtitle', { ns: 'common' }) || 'Please tell us what went wrong with this response'}
</DialogDescription>
<DialogCloseButton className="top-5 right-5 h-8 w-8 rounded-lg" />
</div>
<div className="min-h-0 flex-1 overflow-y-auto px-6 py-3">
<label htmlFor={feedbackTextareaId} className="mb-2 block system-sm-semibold text-text-secondary">
{t('feedback.content', { ns: 'common' }) || 'Feedback Content'}
</label>
<Textarea
id={feedbackTextareaId}
name="feedback-content"
value={feedbackContent}
onChange={e => setFeedbackContent(e.target.value)}
placeholder={t('feedback.placeholder', { ns: 'common' }) || 'Please describe what went wrong or how we can improve…'}
rows={4}
className="w-full"
/>
</div>
<div className="flex shrink-0 justify-end p-6 pt-5">
<Button onClick={handleFeedbackCancel}>
{t('operation.cancel', { ns: 'common' }) || 'Cancel'}
</Button>
<Button
className="ml-2"
variant="primary"
onClick={handleFeedbackSubmit}
>
{t('operation.submit', { ns: 'common' }) || 'Submit'}
</Button>
</div>
</div>
</div>
</Modal>
</DialogContent>
</Dialog>
)}
</>
)

Some files were not shown because too many files have changed in this diff Show More