Merge branch 'main' into 4-27-app-deploy

This commit is contained in:
Stephen Zhou 2026-05-08 09:40:21 +08:00 committed by GitHub
commit 6a62403931
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
235 changed files with 10246 additions and 10484 deletions

View File

@ -109,6 +109,8 @@ jobs:
- name: Web tsslint
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
env:
NODE_OPTIONS: --max-old-space-size=4096
run: vp run lint:tss
- name: Web type check

View File

@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock
```bash
cd dify
cd docker
cp .env.example .env
docker compose up -d
./dify-compose up -d
```
On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory.
After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process.
#### Seeking help
@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases.
### Custom configurations
If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments).
### Metrics Monitoring with Grafana

View File

@ -75,14 +75,15 @@ console_ns.schema_model(
def _convert_values_to_json_serializable_object(value: Segment):
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
match value:
case FileSegment():
return value.value.model_dump()
case ArrayFileSegment():
return [i.model_dump() for i in value.value]
case SegmentGroup():
return [_convert_values_to_json_serializable_object(i) for i in value.value]
case _:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable):

View File

@ -1,9 +0,0 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)

View File

@ -90,7 +90,7 @@ class TestOAuthLogin:
mock_redirect,
mock_get_providers,
resource,
app,
app: Flask,
mock_oauth_provider,
invite_token,
expected_token,
@ -165,7 +165,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -218,7 +218,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_config.CONSOLE_WEB_URL = "http://localhost:3000"
@ -262,7 +262,7 @@ class TestOAuthCallback:
mock_tenant_service,
mock_account_service,
resource,
app,
app: Flask,
oauth_setup,
account_status,
expected_redirect,
@ -301,7 +301,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
mock_get_providers.return_value = {"github": oauth_setup["provider"]}
@ -337,7 +337,7 @@ class TestOAuthCallback:
mock_get_providers,
mock_config,
resource,
app,
app: Flask,
oauth_setup,
):
"""Defensive test for CLOSED account status handling in OAuth callback.
@ -466,7 +466,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
allow_register,
@ -505,7 +505,7 @@ class TestAccountGeneration:
mock_register_service,
mock_feature_service,
mock_get_account,
app,
app: Flask,
):
user_info = OAuthUserInfo(id="123", name="Test User", email="Upper@Example.com")
mock_feature_service.get_system_features.return_value.is_allow_register = True
@ -530,7 +530,7 @@ class TestAccountGeneration:
mock_feature_service,
mock_tenant_service,
mock_get_account,
app,
app: Flask,
user_info,
mock_account,
):

View File

@ -47,7 +47,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
):
# Arrange
@ -105,7 +105,7 @@ class TestForgotPasswordSendEmailApi:
mock_send_email,
mock_get_account,
mock_is_ip_limit,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -154,7 +154,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
"""
Test successful verification code validation.
@ -201,7 +201,7 @@ class TestForgotPasswordCheckApi:
mock_revoke_token,
mock_get_data,
mock_is_rate_limit,
app,
app: Flask,
):
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "User@Example.com", "code": "999888"}
@ -345,7 +345,7 @@ class TestForgotPasswordResetApi:
mock_get_account,
mock_revoke_token,
mock_get_data,
app,
app: Flask,
mock_account,
):
"""

View File

@ -30,7 +30,7 @@ class TestPipelineTemplateListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateListApi()
method = unwrap(api.get)
@ -54,7 +54,7 @@ class TestPipelineTemplateDetailApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -75,7 +75,7 @@ class TestPipelineTemplateDetailApi:
assert status == 200
assert response == template
def test_get_returns_404_when_template_not_found(self, app):
def test_get_returns_404_when_template_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -94,7 +94,7 @@ class TestPipelineTemplateDetailApi:
assert status == 404
assert "error" in response
def test_get_returns_404_for_customized_type_not_found(self, app):
def test_get_returns_404_for_customized_type_not_found(self, app: Flask):
api = PipelineTemplateDetailApi()
method = unwrap(api.get)
@ -119,7 +119,7 @@ class TestCustomizedPipelineTemplateApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.patch)
@ -141,7 +141,7 @@ class TestCustomizedPipelineTemplateApi:
update_mock.assert_called_once()
assert response == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.delete)
@ -156,7 +156,7 @@ class TestCustomizedPipelineTemplateApi:
delete_mock.assert_called_once_with("tpl-1")
assert response == 200
def test_post_success(self, app, db_session_with_containers: Session):
def test_post_success(self, app: Flask, db_session_with_containers: Session):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -183,7 +183,7 @@ class TestCustomizedPipelineTemplateApi:
assert status == 200
assert response == {"data": "yaml-data"}
def test_post_template_not_found(self, app):
def test_post_template_not_found(self, app: Flask):
api = CustomizedPipelineTemplateApi()
method = unwrap(api.post)
@ -197,7 +197,7 @@ class TestPublishCustomizedPipelineTemplateApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = PublishCustomizedPipelineTemplateApi()
method = unwrap(api.post)

View File

@ -36,7 +36,7 @@ class TestRagPipelineImportApi:
"name": "Test",
}
def test_post_success_200(self, app):
def test_post_success_200(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -66,7 +66,7 @@ class TestRagPipelineImportApi:
assert status == 200
assert response == {"status": "success"}
def test_post_failed_400(self, app):
def test_post_failed_400(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -96,7 +96,7 @@ class TestRagPipelineImportApi:
assert status == 400
assert response == {"status": "failed"}
def test_post_pending_202(self, app):
def test_post_pending_202(self, app: Flask):
api = RagPipelineImportApi()
method = unwrap(api.post)
@ -132,7 +132,7 @@ class TestRagPipelineImportConfirmApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_confirm_success(self, app):
def test_confirm_success(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -160,7 +160,7 @@ class TestRagPipelineImportConfirmApi:
assert status == 200
assert response == {"ok": True}
def test_confirm_failed(self, app):
def test_confirm_failed(self, app: Flask):
api = RagPipelineImportConfirmApi()
method = unwrap(api.post)
@ -194,7 +194,7 @@ class TestRagPipelineImportCheckDependenciesApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = RagPipelineImportCheckDependenciesApi()
method = unwrap(api.get)
@ -223,7 +223,7 @@ class TestRagPipelineExportApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_with_include_secret(self, app):
def test_get_with_include_secret(self, app: Flask):
api = RagPipelineExportApi()
method = unwrap(api.get)

View File

@ -391,7 +391,7 @@ class TestPublishedPipelineApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_publish_success(self, app, db_session_with_containers: Session):
def test_publish_success(self, app: Flask, db_session_with_containers: Session):
from models.dataset import Pipeline
api = PublishedRagPipelineApi()

View File

@ -55,7 +55,7 @@ class TestDataSourceApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -79,7 +79,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"][0]["is_bound"] is True
def test_get_no_bindings(self, app, patch_tenant):
def test_get_no_bindings(self, app: Flask, patch_tenant):
api = DataSourceApi()
method = unwrap(api.get)
@ -95,7 +95,7 @@ class TestDataSourceApi:
assert status == 200
assert response["data"] == []
def test_patch_enable_binding(self, app, patch_tenant, mock_engine):
def test_patch_enable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -116,7 +116,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is False
def test_patch_disable_binding(self, app, patch_tenant, mock_engine):
def test_patch_disable_binding(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -137,7 +137,7 @@ class TestDataSourceApi:
assert status == 200
assert binding.disabled is True
def test_patch_binding_not_found(self, app, patch_tenant, mock_engine):
def test_patch_binding_not_found(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -152,7 +152,7 @@ class TestDataSourceApi:
with pytest.raises(NotFound):
method(api, "b1", "enable")
def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine):
def test_patch_enable_already_enabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -169,7 +169,7 @@ class TestDataSourceApi:
with pytest.raises(ValueError):
method(api, "b1", "enable")
def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
def test_patch_disable_already_disabled(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceApi()
method = unwrap(api.patch)
@ -192,7 +192,7 @@ class TestDataSourceNotionListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_credential_not_found(self, app, patch_tenant):
def test_get_credential_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -206,7 +206,7 @@ class TestDataSourceNotionListApi:
with pytest.raises(NotFound):
method(api)
def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_no_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -247,7 +247,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine):
def test_get_success_with_dataset_id(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -300,7 +300,7 @@ class TestDataSourceNotionListApi:
assert status == 200
def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine):
def test_get_invalid_dataset_type(self, app: Flask, patch_tenant, mock_engine):
api = DataSourceNotionListApi()
method = unwrap(api.get)
@ -327,7 +327,7 @@ class TestDataSourceNotionApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_preview_success(self, app, patch_tenant):
def test_get_preview_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.get)
@ -348,7 +348,7 @@ class TestDataSourceNotionApi:
assert status == 200
def test_post_indexing_estimate_success(self, app, patch_tenant):
def test_post_indexing_estimate_success(self, app: Flask, patch_tenant):
api = DataSourceNotionApi()
method = unwrap(api.post)
@ -385,7 +385,7 @@ class TestDataSourceNotionDatasetSyncApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -408,7 +408,7 @@ class TestDataSourceNotionDatasetSyncApi:
assert status == 200
def test_get_dataset_not_found(self, app, patch_tenant):
def test_get_dataset_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDatasetSyncApi()
method = unwrap(api.get)
@ -428,7 +428,7 @@ class TestDataSourceNotionDocumentSyncApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, patch_tenant):
def test_get_success(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)
@ -451,7 +451,7 @@ class TestDataSourceNotionDocumentSyncApi:
assert status == 200
def test_get_document_not_found(self, app, patch_tenant):
def test_get_document_not_found(self, app: Flask, patch_tenant):
api = DataSourceNotionDocumentSyncApi()
method = unwrap(api.get)

View File

@ -57,7 +57,7 @@ class TestConversationListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_success(self, app, chat_app, user):
def test_get_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -82,7 +82,7 @@ class TestConversationListApi:
assert result["has_more"] is False
assert len(result["data"]) == 2
def test_last_conversation_not_exists(self, app, chat_app, user):
def test_last_conversation_not_exists(self, app: Flask, chat_app, user):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -98,7 +98,7 @@ class TestConversationListApi:
with pytest.raises(NotFound):
method(chat_app)
def test_wrong_app_mode(self, app, non_chat_app):
def test_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationListApi()
method = unwrap(api.get)
@ -112,7 +112,7 @@ class TestConversationApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_delete_success(self, app, chat_app, user):
def test_delete_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -130,7 +130,7 @@ class TestConversationApi:
assert status == 204
assert body["result"] == "success"
def test_delete_not_found(self, app, chat_app, user):
def test_delete_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -146,7 +146,7 @@ class TestConversationApi:
with pytest.raises(NotFound):
method(chat_app, "cid")
def test_delete_wrong_app_mode(self, app, non_chat_app):
def test_delete_wrong_app_mode(self, app: Flask, non_chat_app):
api = conversation_module.ConversationApi()
method = unwrap(api.delete)
@ -160,7 +160,7 @@ class TestConversationRenameApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_rename_success(self, app, chat_app, user):
def test_rename_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -179,7 +179,7 @@ class TestConversationRenameApi:
assert result["id"] == "cid"
def test_rename_not_found(self, app, chat_app, user):
def test_rename_not_found(self, app: Flask, chat_app, user):
api = conversation_module.ConversationRenameApi()
method = unwrap(api.post)
@ -201,7 +201,7 @@ class TestConversationPinApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_pin_success(self, app, chat_app, user):
def test_pin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationPinApi()
method = unwrap(api.patch)
@ -223,7 +223,7 @@ class TestConversationUnPinApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_unpin_success(self, app, chat_app, user):
def test_unpin_success(self, app: Flask, chat_app, user):
api = conversation_module.ConversationUnPinApi()
method = unwrap(api.patch)

View File

@ -49,7 +49,7 @@ class TestTriggerProviderApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = TriggerProviderIconApi()
method = unwrap(api.get)
@ -63,7 +63,7 @@ class TestTriggerProviderApis:
):
assert method(api, "github") == "icon"
def test_list_providers(self, app):
def test_list_providers(self, app: Flask):
api = TriggerProviderListApi()
method = unwrap(api.get)
@ -77,7 +77,7 @@ class TestTriggerProviderApis:
):
assert method(api) == []
def test_provider_info(self, app):
def test_provider_info(self, app: Flask):
api = TriggerProviderInfoApi()
method = unwrap(api.get)
@ -97,7 +97,7 @@ class TestTriggerSubscriptionListApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -111,7 +111,7 @@ class TestTriggerSubscriptionListApi:
):
assert method(api, "github") == []
def test_list_invalid_provider(self, app):
def test_list_invalid_provider(self, app: Flask):
api = TriggerSubscriptionListApi()
method = unwrap(api.get)
@ -132,7 +132,7 @@ class TestTriggerSubscriptionBuilderApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_create_builder(self, app):
def test_create_builder(self, app: Flask):
api = TriggerSubscriptionBuilderCreateApi()
method = unwrap(api.post)
@ -147,7 +147,7 @@ class TestTriggerSubscriptionBuilderApis:
result = method(api, "github")
assert "subscription_builder" in result
def test_get_builder(self, app):
def test_get_builder(self, app: Flask):
api = TriggerSubscriptionBuilderGetApi()
method = unwrap(api.get)
@ -160,7 +160,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_verify_builder(self, app):
def test_verify_builder(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -174,7 +174,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"ok": True}
def test_verify_builder_error(self, app):
def test_verify_builder_error(self, app: Flask):
api = TriggerSubscriptionBuilderVerifyApi()
method = unwrap(api.post)
@ -189,7 +189,7 @@ class TestTriggerSubscriptionBuilderApis:
with pytest.raises(ValueError):
method(api, "github", "b1")
def test_update_builder(self, app):
def test_update_builder(self, app: Flask):
api = TriggerSubscriptionBuilderUpdateApi()
method = unwrap(api.post)
@ -203,7 +203,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert method(api, "github", "b1") == {"id": "b1"}
def test_logs(self, app):
def test_logs(self, app: Flask):
api = TriggerSubscriptionBuilderLogsApi()
method = unwrap(api.get)
@ -220,7 +220,7 @@ class TestTriggerSubscriptionBuilderApis:
):
assert "logs" in method(api, "github", "b1")
def test_build(self, app):
def test_build(self, app: Flask):
api = TriggerSubscriptionBuilderBuildApi()
method = unwrap(api.post)
@ -240,7 +240,7 @@ class TestTriggerSubscriptionCrud:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_update_rename_only(self, app):
def test_update_rename_only(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -259,7 +259,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_update_not_found(self, app):
def test_update_not_found(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -274,7 +274,7 @@ class TestTriggerSubscriptionCrud:
with pytest.raises(NotFoundError):
method(api, "x")
def test_update_rebuild(self, app):
def test_update_rebuild(self, app: Flask):
api = TriggerSubscriptionUpdateApi()
method = unwrap(api.post)
@ -297,7 +297,7 @@ class TestTriggerSubscriptionCrud:
):
assert method(api, "s1") == 200
def test_delete_subscription(self, app):
def test_delete_subscription(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -320,7 +320,7 @@ class TestTriggerSubscriptionCrud:
assert result["result"] == "success"
def test_delete_subscription_value_error(self, app):
def test_delete_subscription_value_error(self, app: Flask):
api = TriggerSubscriptionDeleteApi()
method = unwrap(api.post)
@ -346,7 +346,7 @@ class TestTriggerOAuthApis:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_oauth_authorize_success(self, app):
def test_oauth_authorize_success(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -373,7 +373,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 200
def test_oauth_authorize_no_client(self, app):
def test_oauth_authorize_no_client(self, app: Flask):
api = TriggerOAuthAuthorizeApi()
method = unwrap(api.get)
@ -388,7 +388,7 @@ class TestTriggerOAuthApis:
with pytest.raises(NotFoundError):
method(api, "github")
def test_oauth_callback_forbidden(self, app):
def test_oauth_callback_forbidden(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -396,7 +396,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_success(self, app):
def test_oauth_callback_success(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -426,7 +426,7 @@ class TestTriggerOAuthApis:
resp = method(api, "github")
assert resp.status_code == 302
def test_oauth_callback_no_oauth_client(self, app):
def test_oauth_callback_no_oauth_client(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -450,7 +450,7 @@ class TestTriggerOAuthApis:
with pytest.raises(Forbidden):
method(api, "github")
def test_oauth_callback_empty_credentials(self, app):
def test_oauth_callback_empty_credentials(self, app: Flask):
api = TriggerOAuthCallbackApi()
method = unwrap(api.get)
@ -484,7 +484,7 @@ class TestTriggerOAuthClientManageApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_get_client(self, app):
def test_get_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.get)
@ -511,7 +511,7 @@ class TestTriggerOAuthClientManageApi:
result = method(api, "github")
assert "configured" in result
def test_post_client(self, app):
def test_post_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -525,7 +525,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_delete_client(self, app):
def test_delete_client(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.delete)
@ -539,7 +539,7 @@ class TestTriggerOAuthClientManageApi:
):
assert method(api, "github") == {"ok": True}
def test_oauth_client_post_value_error(self, app):
def test_oauth_client_post_value_error(self, app: Flask):
api = TriggerOAuthClientManageApi()
method = unwrap(api.post)
@ -560,7 +560,7 @@ class TestTriggerSubscriptionVerifyApi:
def app(self, flask_app_with_containers: Flask):
return flask_app_with_containers
def test_verify_success(self, app):
def test_verify_success(self, app: Flask):
api = TriggerSubscriptionVerifyApi()
method = unwrap(api.post)

View File

@ -291,7 +291,7 @@ class TestDatasetListApiGet:
mock_current_user,
mock_provider_mgr,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -326,7 +326,7 @@ class TestDatasetListApiPost:
mock_dataset_svc,
mock_current_user,
mock_marshal,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -352,7 +352,7 @@ class TestDatasetListApiPost:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_tenant,
):
from controllers.service_api.dataset.dataset import DatasetListApi
@ -390,7 +390,7 @@ class TestDatasetApiGet:
mock_provider_mgr,
mock_marshal,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -440,7 +440,7 @@ class TestDatasetApiGet:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -468,7 +468,7 @@ class TestDatasetApiDelete:
mock_dataset_svc,
mock_current_user,
mock_perm_svc,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -490,7 +490,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -511,7 +511,7 @@ class TestDatasetApiDelete:
self,
mock_dataset_svc,
mock_current_user,
app,
app: Flask,
mock_dataset,
):
from controllers.service_api.dataset.dataset import DatasetApi
@ -543,7 +543,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -574,7 +574,7 @@ class TestDocumentStatusApiPatch:
def test_batch_update_status_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -603,7 +603,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -636,7 +636,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -669,7 +669,7 @@ class TestDocumentStatusApiPatch:
mock_dataset_svc,
mock_current_user,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -709,7 +709,7 @@ class TestDatasetTagsApiGet:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -731,7 +731,7 @@ class TestDatasetTagsApiGet:
def test_list_tags_from_db(
self,
mock_current_user,
app,
app: Flask,
db_session_with_containers: Session,
):
"""Integration test: creates real Tag rows and retrieves them
@ -774,7 +774,7 @@ class TestDatasetTagsApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -797,7 +797,7 @@ class TestDatasetTagsApiPost:
mock_tag_svc.save_tags.assert_called_once()
@patch("controllers.service_api.dataset.dataset.current_user")
def test_create_tag_forbidden(self, mock_current_user, app):
def test_create_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -826,7 +826,7 @@ class TestDatasetTagsApiPatch:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -852,7 +852,7 @@ class TestDatasetTagsApiPatch:
mock_tag_svc.update_tags.assert_called_once_with({"name": "Updated Tag", "type": "knowledge"}, "tag-1")
@patch("controllers.service_api.dataset.dataset.current_user")
def test_update_tag_forbidden(self, mock_current_user, app):
def test_update_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
mock_current_user.__class__ = Account
@ -880,7 +880,7 @@ class TestDatasetTagsApiDelete:
mock_current_user,
mock_service_api_ns,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsApi
@ -905,7 +905,7 @@ class TestDatasetTagsApiDelete:
mock_tag_svc.delete_tag.assert_called_once_with("tag-1")
@patch("libs.login.current_user")
def test_delete_tag_forbidden(self, mock_current_user, app):
def test_delete_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagsApi
user_obj = Mock(spec=Account)
@ -933,7 +933,7 @@ class TestDatasetTagsBindingStatusApi:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagsBindingStatusApi
@ -963,7 +963,7 @@ class TestDatasetTagBindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
@ -988,7 +988,7 @@ class TestDatasetTagBindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_bind_tags_forbidden(self, mock_current_user, app):
def test_bind_tags_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagBindingApi
mock_current_user.__class__ = Account
@ -1014,7 +1014,7 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1044,7 +1044,7 @@ class TestDatasetTagUnbindingApiPost:
self,
mock_current_user,
mock_tag_svc,
app,
app: Flask,
):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
@ -1069,7 +1069,7 @@ class TestDatasetTagUnbindingApiPost:
)
@patch("controllers.service_api.dataset.dataset.current_user")
def test_unbind_tag_forbidden(self, mock_current_user, app):
def test_unbind_tag_forbidden(self, mock_current_user, app: Flask):
from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi
mock_current_user.__class__ = Account

View File

@ -240,7 +240,7 @@ class TestDecodeJwtToken:
mock_access_mode: MagicMock,
mock_validate_token: MagicMock,
mock_validate_user: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)
@ -300,7 +300,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers, enable_site=False)
@ -325,7 +325,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, _ = self._create_app_site_enduser(db_session_with_containers)
@ -351,7 +351,7 @@ class TestDecodeJwtToken:
mock_extract: MagicMock,
mock_passport_cls: MagicMock,
mock_features: MagicMock,
app,
app: Flask,
db_session_with_containers: Session,
) -> None:
app_model, site, end_user = self._create_app_site_enduser(db_session_with_containers)

View File

@ -21,6 +21,8 @@ from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue
from extensions.ext_redis import redis_client
from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus
TenantAndAccount = tuple[Tenant, Account]
@dataclass
class TestTask:
@ -74,18 +76,18 @@ class TestTenantIsolatedTaskQueueIntegration:
return tenant, account
@pytest.fixture
def test_queue(self, test_tenant_and_account):
def test_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a generic test queue for testing."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "test_queue")
@pytest.fixture
def secondary_queue(self, test_tenant_and_account):
def secondary_queue(self, test_tenant_and_account: TenantAndAccount):
"""Create a secondary test queue for testing isolation."""
tenant, _ = test_tenant_and_account
return TenantIsolatedTaskQueue(tenant.id, "secondary_queue")
def test_queue_initialization(self, test_tenant_and_account):
def test_queue_initialization(self, test_tenant_and_account: TenantAndAccount):
"""Test queue initialization with correct key generation."""
tenant, _ = test_tenant_and_account
queue = TenantIsolatedTaskQueue(tenant.id, "test-key")
@ -95,7 +97,9 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}"
assert queue._task_key == f"tenant_test-key_task:{tenant.id}"
def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker):
def test_tenant_isolation(
self, test_tenant_and_account: TenantAndAccount, db_session_with_containers: Session, fake: Faker
):
"""Test that different tenants have isolated queues."""
tenant1, _ = test_tenant_and_account
@ -115,7 +119,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert queue1._queue == f"tenant_self_same-key_task_queue:{tenant1.id}"
assert queue2._queue == f"tenant_self_same-key_task_queue:{tenant2.id}"
def test_key_isolation(self, test_tenant_and_account):
def test_key_isolation(self, test_tenant_and_account: TenantAndAccount):
"""Test that different keys have isolated queues."""
tenant, _ = test_tenant_and_account
queue1 = TenantIsolatedTaskQueue(tenant.id, "key1")
@ -293,7 +297,7 @@ class TestTenantIsolatedTaskQueueIntegration:
assert isinstance(task, dict)
assert task["index"] == i # FIFO order
def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker):
def test_queue_operations_isolation(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""Test concurrent operations on different queues."""
tenant, _ = test_tenant_and_account
@ -436,7 +440,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
return tenant, account
def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker):
def test_legacy_string_queue_compatibility(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test compatibility with legacy queues containing only string data.
@ -466,7 +470,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"]
assert pulled_tasks == expected_order
def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker):
def test_legacy_queue_migration_scenario(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test complete migration scenario from legacy to new system.
@ -547,7 +551,7 @@ class TestTenantIsolatedTaskQueueCompatibility:
assert task["tenant_id"] == tenant.id
assert task["processing_type"] == "new_system"
def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker):
def test_legacy_queue_error_recovery(self, test_tenant_and_account: TenantAndAccount, fake: Faker):
"""
Test error recovery when legacy queue contains malformed data.

View File

@ -0,0 +1,94 @@
from __future__ import annotations
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
def _tenant_id() -> str:
return str(uuid4())
def _get_permission(session: Session, tenant_id: str) -> TenantPluginPermission | None:
session.expire_all()
stmt = select(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
return session.scalars(stmt).one_or_none()
def _count_permissions(session: Session, tenant_id: str) -> int:
stmt = select(func.count()).select_from(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id)
return session.scalar(stmt) or 0
class TestGetPermission:
"""Integration tests for PluginPermissionService.get_permission using testcontainers."""
def test_returns_permission_when_found(self, db_session_with_containers: Session):
tenant_id = _tenant_id()
permission = TenantPluginPermission(
tenant_id=tenant_id,
install_permission=TenantPluginPermission.InstallPermission.ADMINS,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
db_session_with_containers.add(permission)
db_session_with_containers.commit()
result = PluginPermissionService.get_permission(tenant_id)
assert result is not None
assert result.id == permission.id
assert result.tenant_id == tenant_id
assert result.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert result.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
def test_returns_none_when_not_found(self, db_session_with_containers: Session):
result = PluginPermissionService.get_permission(_tenant_id())
assert result is None
class TestChangePermission:
"""Integration tests for PluginPermissionService.change_permission using testcontainers."""
def test_creates_new_permission_when_not_exists(self, db_session_with_containers: Session):
tenant_id = _tenant_id()
result = PluginPermissionService.change_permission(
tenant_id,
TenantPluginPermission.InstallPermission.EVERYONE,
TenantPluginPermission.DebugPermission.EVERYONE,
)
permission = _get_permission(db_session_with_containers, tenant_id)
assert result is True
assert permission is not None
assert permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE
assert permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE
def test_updates_existing_permission(self, db_session_with_containers: Session):
tenant_id = _tenant_id()
existing = TenantPluginPermission(
tenant_id=tenant_id,
install_permission=TenantPluginPermission.InstallPermission.EVERYONE,
debug_permission=TenantPluginPermission.DebugPermission.EVERYONE,
)
db_session_with_containers.add(existing)
db_session_with_containers.commit()
result = PluginPermissionService.change_permission(
tenant_id,
TenantPluginPermission.InstallPermission.ADMINS,
TenantPluginPermission.DebugPermission.ADMINS,
)
permission = _get_permission(db_session_with_containers, tenant_id)
assert result is True
assert permission is not None
assert permission.id == existing.id
assert permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
assert _count_permissions(db_session_with_containers, tenant_id) == 1

View File

@ -7,6 +7,7 @@ from faker import Faker
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import App
from models.model import EndUser
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
@ -184,7 +185,7 @@ class TestAppGenerateService:
return app, account
def _create_test_workflow(self, db_session_with_containers: Session, app):
def _create_test_workflow(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test workflow for testing.

View File

@ -165,7 +165,7 @@ class TestMessagesCleanServiceIntegration:
return app
def _create_conversation(self, db_session_with_containers: Session, app):
def _create_conversation(self, db_session_with_containers: Session, app: App):
"""Helper to create a conversation."""
conversation = Conversation(
app_id=app.id,

View File

@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models import App, CreatorUserRole
from models.enums import ConversationFromSource
from models.model import EndUser, Message
from models.web import SavedMessage
@ -88,7 +89,7 @@ class TestSavedMessageService:
return app, account
def _create_test_end_user(self, db_session_with_containers: Session, app):
def _create_test_end_user(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test end user for testing.
@ -116,7 +117,7 @@ class TestSavedMessageService:
return end_user
def _create_test_message(self, db_session_with_containers: Session, app, user):
def _create_test_message(self, db_session_with_containers: Session, app: App, user):
"""
Helper method to create a test message for testing.
@ -199,13 +200,13 @@ class TestSavedMessageService:
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -272,13 +273,13 @@ class TestSavedMessageService:
saved_message1 = SavedMessage(
app_id=app.id,
message_id=message1.id,
created_by_role="end_user",
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
saved_message2 = SavedMessage(
app_id=app.id,
message_id=message2.id,
created_by_role="end_user",
created_by_role=CreatorUserRole.END_USER,
created_by=end_user.id,
)
@ -449,7 +450,7 @@ class TestSavedMessageService:
saved_message = SavedMessage(
app_id=app.id,
message_id=message.id,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=account.id,
)
@ -540,7 +541,9 @@ class TestSavedMessageService:
message = self._create_test_message(db_session_with_containers, app, account)
# Pre-create a saved message
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="account", created_by=account.id)
saved = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id
)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
@ -571,7 +574,9 @@ class TestSavedMessageService:
end_user = self._create_test_end_user(db_session_with_containers, app)
message = self._create_test_message(db_session_with_containers, app, end_user)
saved = SavedMessage(app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id)
saved = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id
)
db_session_with_containers.add(saved)
db_session_with_containers.commit()
@ -596,10 +601,10 @@ class TestSavedMessageService:
# Both users save the same message
saved_account = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="account", created_by=account1.id
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.ACCOUNT, created_by=account1.id
)
saved_end_user = SavedMessage(
app_id=app.id, message_id=message.id, created_by_role="end_user", created_by=end_user.id
app_id=app.id, message_id=message.id, created_by_role=CreatorUserRole.END_USER, created_by=end_user.id
)
db_session_with_containers.add_all([saved_account, saved_end_user])
db_session_with_containers.commit()

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from models import Account
from models import Account, App
from models.enums import ConversationFromSource
from models.model import Conversation, EndUser
from models.web import PinnedConversation
@ -93,7 +93,7 @@ class TestWebConversationService:
return app, account
def _create_test_end_user(self, db_session_with_containers: Session, app):
def _create_test_end_user(self, db_session_with_containers: Session, app: App):
"""
Helper method to create a test end user for testing.

View File

@ -185,7 +185,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):
@ -263,7 +263,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):
@ -312,7 +312,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
language,
@ -358,7 +358,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
):
"""
@ -398,7 +398,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
):
"""
@ -438,7 +438,7 @@ class TestActivateApi:
mock_db,
mock_revoke_token,
mock_get_invitation,
app,
app: Flask,
mock_invitation,
mock_account,
):

View File

@ -195,7 +195,7 @@ class TestEmailCodeLoginSendEmailApi:
mock_get_user,
mock_is_ip_limit,
mock_db,
app,
app: Flask,
mock_account,
language_input,
expected_language,
@ -267,7 +267,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -315,7 +315,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -431,7 +431,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""
@ -474,7 +474,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""
@ -515,7 +515,7 @@ class TestEmailCodeLoginApi:
mock_revoke_token,
mock_get_data,
mock_db,
app,
app: Flask,
mock_account,
):
"""

View File

@ -412,7 +412,7 @@ class TestLoginApi:
mock_get_invitation,
mock_is_rate_limit,
mock_db,
app,
app: Flask,
mock_account,
mock_token_pair,
):
@ -448,7 +448,7 @@ class TestLoginApi:
mock_revoke_token,
mock_get_token_data,
mock_db,
app,
app: Flask,
):
mock_get_token_data.return_value = {"email": "User@Example.com", "code": "123456"}
mock_get_account.side_effect = Unauthorized("Account is banned.")

View File

@ -74,7 +74,7 @@ class TestRefreshTokenApi:
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
def test_refresh_fails_without_token(self, mock_extract_token, app):
def test_refresh_fails_without_token(self, mock_extract_token, app: Flask):
"""
Test token refresh failure when no refresh token provided.
@ -98,7 +98,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh failure with invalid refresh token.
@ -123,7 +123,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh failure with expired refresh token.
@ -148,7 +148,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app: Flask):
"""
Test token refresh with empty string token.

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
from controllers.console import console_ns
@ -29,7 +30,7 @@ def unwrap(func):
class TestDatasourcePluginOAuthAuthorizationUrl:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -61,7 +62,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
assert response.status_code == 200
def test_get_no_oauth_config(self, app):
def test_get_no_oauth_config(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -80,7 +81,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_without_credential_id_sets_cookie(self, app):
def test_get_without_credential_id_sets_cookie(self, app: Flask):
api = DatasourcePluginOAuthAuthorizationUrl()
method = unwrap(api.get)
@ -115,7 +116,7 @@ class TestDatasourcePluginOAuthAuthorizationUrl:
class TestDatasourceOAuthCallback:
def test_callback_success_new_credential(self, app):
def test_callback_success_new_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -157,7 +158,7 @@ class TestDatasourceOAuthCallback:
assert response.status_code == 302
def test_callback_missing_context(self, app):
def test_callback_missing_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -165,7 +166,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_invalid_context(self, app):
def test_callback_invalid_context(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -180,7 +181,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(Forbidden):
method(api, "notion")
def test_callback_oauth_config_not_found(self, app):
def test_callback_oauth_config_not_found(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -202,7 +203,7 @@ class TestDatasourceOAuthCallback:
with pytest.raises(NotFound):
method(api, "notion")
def test_callback_reauthorize_existing_credential(self, app):
def test_callback_reauthorize_existing_credential(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -245,7 +246,7 @@ class TestDatasourceOAuthCallback:
assert response.status_code == 302
assert "/oauth-callback" in response.location
def test_callback_context_id_from_cookie(self, app):
def test_callback_context_id_from_cookie(self, app: Flask):
api = DatasourceOAuthCallback()
method = unwrap(api.get)
@ -289,7 +290,7 @@ class TestDatasourceOAuthCallback:
class TestDatasourceAuth:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -312,7 +313,7 @@ class TestDatasourceAuth:
assert status == 200
def test_post_invalid_credentials(self, app):
def test_post_invalid_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -334,7 +335,7 @@ class TestDatasourceAuth:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
@ -355,7 +356,7 @@ class TestDatasourceAuth:
assert status == 200
assert response["result"]
def test_post_missing_credentials(self, app):
def test_post_missing_credentials(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.post)
@ -372,7 +373,7 @@ class TestDatasourceAuth:
with pytest.raises(ValueError):
method(api, "notion")
def test_get_empty_list(self, app):
def test_get_empty_list(self, app: Flask):
api = DatasourceAuth()
method = unwrap(api.get)
@ -395,7 +396,7 @@ class TestDatasourceAuth:
class TestDatasourceAuthDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
@ -418,7 +419,7 @@ class TestDatasourceAuthDeleteApi:
assert status == 200
def test_delete_missing_credential_id(self, app):
def test_delete_missing_credential_id(self, app: Flask):
api = DatasourceAuthDeleteApi()
method = unwrap(api.post)
@ -437,7 +438,7 @@ class TestDatasourceAuthDeleteApi:
class TestDatasourceAuthUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -460,7 +461,7 @@ class TestDatasourceAuthUpdateApi:
assert status == 201
def test_update_with_credentials_none(self, app):
def test_update_with_credentials_none(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -484,7 +485,7 @@ class TestDatasourceAuthUpdateApi:
update_mock.assert_called_once()
assert status == 201
def test_update_name_only(self, app):
def test_update_name_only(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -507,7 +508,7 @@ class TestDatasourceAuthUpdateApi:
assert status == 201
def test_update_with_empty_credentials_dict(self, app):
def test_update_with_empty_credentials_dict(self, app: Flask):
api = DatasourceAuthUpdateApi()
method = unwrap(api.post)
@ -533,7 +534,7 @@ class TestDatasourceAuthUpdateApi:
class TestDatasourceAuthListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
@ -553,7 +554,7 @@ class TestDatasourceAuthListApi:
assert status == 200
def test_auth_list_empty(self, app):
def test_auth_list_empty(self, app: Flask):
api = DatasourceAuthListApi()
method = unwrap(api.get)
@ -574,7 +575,7 @@ class TestDatasourceAuthListApi:
assert status == 200
assert response["result"] == []
def test_hardcode_list_empty(self, app):
def test_hardcode_list_empty(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
@ -597,7 +598,7 @@ class TestDatasourceAuthListApi:
class TestDatasourceHardCodeAuthListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = DatasourceHardCodeAuthListApi()
method = unwrap(api.get)
@ -619,7 +620,7 @@ class TestDatasourceHardCodeAuthListApi:
class TestDatasourceAuthOauthCustomClient:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -642,7 +643,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.delete)
@ -662,7 +663,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_post_empty_payload(self, app):
def test_post_empty_payload(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -685,7 +686,7 @@ class TestDatasourceAuthOauthCustomClient:
assert status == 200
def test_post_disabled_flag(self, app):
def test_post_disabled_flag(self, app: Flask):
api = DatasourceAuthOauthCustomClient()
method = unwrap(api.post)
@ -714,7 +715,7 @@ class TestDatasourceAuthOauthCustomClient:
class TestDatasourceAuthDefaultApi:
def test_set_default_success(self, app):
def test_set_default_success(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
@ -737,7 +738,7 @@ class TestDatasourceAuthDefaultApi:
assert status == 200
def test_default_missing_id(self, app):
def test_default_missing_id(self, app: Flask):
api = DatasourceAuthDefaultApi()
method = unwrap(api.post)
@ -756,7 +757,7 @@ class TestDatasourceAuthDefaultApi:
class TestDatasourceUpdateProviderNameApi:
def test_update_name_success(self, app):
def test_update_name_success(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
@ -779,7 +780,7 @@ class TestDatasourceUpdateProviderNameApi:
assert status == 200
def test_update_name_too_long(self, app):
def test_update_name_too_long(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)
@ -799,7 +800,7 @@ class TestDatasourceUpdateProviderNameApi:
with pytest.raises(ValueError):
method(api, "notion")
def test_update_name_missing_credential_id(self, app):
def test_update_name_missing_credential_id(self, app: Flask):
api = DatasourceUpdateProviderNameApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.console import console_ns
@ -25,7 +26,7 @@ class TestDataSourceContentPreviewApi:
"credential_id": "cred-1",
}
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -66,7 +67,7 @@ class TestDataSourceContentPreviewApi:
assert status == 200
assert response == preview_result
def test_post_forbidden_non_account_user(self, app):
def test_post_forbidden_non_account_user(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -85,7 +86,7 @@ class TestDataSourceContentPreviewApi:
with pytest.raises(Forbidden):
method(api, pipeline, "node-1")
def test_post_invalid_payload(self, app):
def test_post_invalid_payload(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)
@ -108,7 +109,7 @@ class TestDataSourceContentPreviewApi:
with pytest.raises(ValueError):
method(api, pipeline, "node-1")
def test_post_without_credential_id(self, app):
def test_post_without_credential_id(self, app: Flask):
api = DataSourceContentPreviewApi()
method = unwrap(api.post)

View File

@ -2,6 +2,7 @@ import datetime
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
import services
@ -58,7 +59,7 @@ class TestDatasetList:
user.is_dataset_editor = True
return user
def test_get_success_basic(self, app):
def test_get_success_basic(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -93,7 +94,7 @@ class TestDatasetList:
assert resp["total"] == 1
assert resp["data"][0]["embedding_available"] is True
def test_get_with_ids_filter(self, app):
def test_get_with_ids_filter(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -128,7 +129,7 @@ class TestDatasetList:
assert status == 200
assert resp["total"] == 2
def test_get_with_tag_ids(self, app):
def test_get_with_tag_ids(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -161,7 +162,7 @@ class TestDatasetList:
assert status == 200
def test_embedding_available_false(self, app):
def test_embedding_available_false(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -203,7 +204,7 @@ class TestDatasetList:
assert resp["data"][0]["embedding_available"] is False
def test_partial_members_permission(self, app):
def test_partial_members_permission(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.get)
@ -242,7 +243,7 @@ class TestDatasetList:
class TestDatasetListApiPost:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -290,7 +291,7 @@ class TestDatasetListApiPost:
assert status == 201
def test_post_forbidden(self, app):
def test_post_forbidden(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -310,7 +311,7 @@ class TestDatasetListApiPost:
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
def test_post_duplicate_name(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -335,7 +336,7 @@ class TestDatasetListApiPost:
with pytest.raises(DatasetNameDuplicateError):
method(api)
def test_post_invalid_payload_missing_name(self, app):
def test_post_invalid_payload_missing_name(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -343,7 +344,7 @@ class TestDatasetListApiPost:
with pytest.raises(ValueError):
method(api)
def test_post_invalid_indexing_technique(self, app):
def test_post_invalid_indexing_technique(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -356,7 +357,7 @@ class TestDatasetListApiPost:
with pytest.raises(ValueError, match="Invalid indexing technique"):
method(api)
def test_post_invalid_provider(self, app):
def test_post_invalid_provider(self, app: Flask):
api = DatasetListApi()
method = unwrap(api.post)
@ -371,7 +372,7 @@ class TestDatasetListApiPost:
class TestDatasetApiGet:
def test_get_success_basic(self, app):
def test_get_success_basic(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -427,7 +428,7 @@ class TestDatasetApiGet:
assert status == 200
assert data["embedding_available"] is True
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -448,7 +449,7 @@ class TestDatasetApiGet:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -475,7 +476,7 @@ class TestDatasetApiGet:
with pytest.raises(Forbidden, match="no access"):
method(api, dataset_id)
def test_get_high_quality_embedding_unavailable(self, app):
def test_get_high_quality_embedding_unavailable(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -530,7 +531,7 @@ class TestDatasetApiGet:
assert data["embedding_available"] is False
def test_get_partial_members_permission(self, app):
def test_get_partial_members_permission(self, app: Flask):
api = DatasetApi()
method = unwrap(api.get)
@ -590,7 +591,7 @@ class TestDatasetApiGet:
class TestDatasetApiPatch:
def test_patch_success_basic(self, app):
def test_patch_success_basic(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -659,7 +660,7 @@ class TestDatasetApiPatch:
assert status == 200
assert result["partial_member_list"] == []
def test_patch_dataset_not_found(self, app):
def test_patch_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -674,7 +675,7 @@ class TestDatasetApiPatch:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, "missing")
def test_patch_permission_denied(self, app):
def test_patch_permission_denied(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -704,7 +705,7 @@ class TestDatasetApiPatch:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_patch_partial_members_update(self, app):
def test_patch_partial_members_update(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -773,7 +774,7 @@ class TestDatasetApiPatch:
assert result["partial_member_list"] == payload["partial_member_list"]
def test_patch_clear_partial_members(self, app):
def test_patch_clear_partial_members(self, app: Flask):
api = DatasetApi()
method = unwrap(api.patch)
@ -843,7 +844,7 @@ class TestDatasetApiPatch:
class TestDatasetApiDelete:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -874,7 +875,7 @@ class TestDatasetApiDelete:
assert status == 204
assert result == {"result": "success"}
def test_delete_forbidden_no_permission(self, app):
def test_delete_forbidden_no_permission(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -893,7 +894,7 @@ class TestDatasetApiDelete:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_delete_dataset_not_found(self, app):
def test_delete_dataset_not_found(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -917,7 +918,7 @@ class TestDatasetApiDelete:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_delete_dataset_in_use(self, app):
def test_delete_dataset_in_use(self, app: Flask):
api = DatasetApi()
method = unwrap(api.delete)
@ -943,7 +944,7 @@ class TestDatasetApiDelete:
class TestDatasetUseCheckApi:
def test_get_use_check_true(self, app):
def test_get_use_check_true(self, app: Flask):
api = DatasetUseCheckApi()
method = unwrap(api.get)
@ -962,7 +963,7 @@ class TestDatasetUseCheckApi:
assert status == 200
assert result == {"is_using": True}
def test_get_use_check_false(self, app):
def test_get_use_check_false(self, app: Flask):
api = DatasetUseCheckApi()
method = unwrap(api.get)
@ -983,7 +984,7 @@ class TestDatasetUseCheckApi:
class TestDatasetQueryApi:
def test_get_queries_success(self, app):
def test_get_queries_success(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1027,7 +1028,7 @@ class TestDatasetQueryApi:
assert response["has_more"] is False
assert len(response["data"]) == 2
def test_get_queries_dataset_not_found(self, app):
def test_get_queries_dataset_not_found(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1049,7 +1050,7 @@ class TestDatasetQueryApi:
with pytest.raises(NotFound, match="Dataset not found"):
method(api, dataset_id)
def test_get_queries_permission_denied(self, app):
def test_get_queries_permission_denied(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1078,7 +1079,7 @@ class TestDatasetQueryApi:
with pytest.raises(Forbidden):
method(api, dataset_id)
def test_get_queries_pagination_has_more(self, app):
def test_get_queries_pagination_has_more(self, app: Flask):
api = DatasetQueryApi()
method = unwrap(api.get)
@ -1152,7 +1153,7 @@ class TestDatasetIndexingEstimateApi:
"dataset_id": None,
}
def test_post_success_upload_file(self, app):
def test_post_success_upload_file(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
@ -1193,7 +1194,7 @@ class TestDatasetIndexingEstimateApi:
assert status == 200
assert response == {"tokens": 100}
def test_post_file_not_found(self, app):
def test_post_file_not_found(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
@ -1223,7 +1224,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(NotFound):
method(api)
def test_post_llm_bad_request_error(self, app):
def test_post_llm_bad_request_error(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1258,7 +1259,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(ProviderNotInitializeError):
method(api)
def test_post_provider_token_not_init(self, app):
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1293,7 +1294,7 @@ class TestDatasetIndexingEstimateApi:
with pytest.raises(ProviderNotInitializeError):
method(api)
def test_post_generic_exception(self, app):
def test_post_generic_exception(self, app: Flask):
api = DatasetIndexingEstimateApi()
method = unwrap(api.post)
mock_file = self._upload_file()
@ -1330,7 +1331,7 @@ class TestDatasetIndexingEstimateApi:
class TestDatasetRelatedAppListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1368,7 +1369,7 @@ class TestDatasetRelatedAppListApi:
assert response["total"] == 2
assert response["data"] == [app1, app2]
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1386,7 +1387,7 @@ class TestDatasetRelatedAppListApi:
with pytest.raises(NotFound):
method(api, "dataset-1")
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1410,7 +1411,7 @@ class TestDatasetRelatedAppListApi:
with pytest.raises(Forbidden):
method(api, "dataset-1")
def test_get_filters_none_apps(self, app):
def test_get_filters_none_apps(self, app: Flask):
api = DatasetRelatedAppListApi()
method = unwrap(api.get)
@ -1449,7 +1450,7 @@ class TestDatasetRelatedAppListApi:
class TestDatasetIndexingStatusApi:
def test_get_success_with_documents(self, app):
def test_get_success_with_documents(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1490,7 +1491,7 @@ class TestDatasetIndexingStatusApi:
assert item["completed_segments"] == 3
assert item["total_segments"] == 3
def test_get_success_no_documents(self, app):
def test_get_success_no_documents(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1510,7 +1511,7 @@ class TestDatasetIndexingStatusApi:
assert status == 200
assert response == {"data": []}
def test_segment_counts_different_values(self, app):
def test_segment_counts_different_values(self, app: Flask):
api = DatasetIndexingStatusApi()
method = unwrap(api.get)
@ -1550,7 +1551,7 @@ class TestDatasetIndexingStatusApi:
class TestDatasetApiKeyApi:
def test_get_api_keys_success(self, app):
def test_get_api_keys_success(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.get)
@ -1587,7 +1588,7 @@ class TestDatasetApiKeyApi:
assert response["data"][1]["id"] == "key-2"
assert response["data"][1]["token"] == "ds-def"
def test_post_create_api_key_success(self, app):
def test_post_create_api_key_success(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.post)
@ -1632,7 +1633,7 @@ class TestDatasetApiKeyApi:
assert response["type"] == "dataset"
assert response["created_at"] is not None
def test_post_exceed_max_keys(self, app):
def test_post_exceed_max_keys(self, app: Flask):
api = DatasetApiKeyApi()
method = unwrap(api.post)
@ -1658,7 +1659,7 @@ class TestDatasetApiKeyApi:
class TestDatasetApiDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DatasetApiDeleteApi()
method = unwrap(api.delete)
@ -1688,7 +1689,7 @@ class TestDatasetApiDeleteApi:
assert status == 204
assert response["result"] == "success"
def test_delete_key_not_found(self, app):
def test_delete_key_not_found(self, app: Flask):
api = DatasetApiDeleteApi()
method = unwrap(api.delete)
@ -1708,7 +1709,7 @@ class TestDatasetApiDeleteApi:
class TestDatasetEnableApiApi:
def test_enable_api(self, app):
def test_enable_api(self, app: Flask):
api = DatasetEnableApiApi()
method = unwrap(api.post)
@ -1724,7 +1725,7 @@ class TestDatasetEnableApiApi:
assert status == 200
assert response["result"] == "success"
def test_disable_api(self, app):
def test_disable_api(self, app: Flask):
api = DatasetEnableApiApi()
method = unwrap(api.post)
@ -1742,7 +1743,7 @@ class TestDatasetEnableApiApi:
class TestDatasetApiBaseUrlApi:
def test_get_api_base_url_from_config(self, app):
def test_get_api_base_url_from_config(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1757,7 +1758,7 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "https://example.com/v1"
def test_get_api_base_url_from_request(self, app):
def test_get_api_base_url_from_request(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1772,7 +1773,7 @@ class TestDatasetApiBaseUrlApi:
assert response["api_base_url"] == "http://localhost:5000/v1"
def test_get_api_base_url_no_double_v1(self, app):
def test_get_api_base_url_no_double_v1(self, app: Flask):
api = DatasetApiBaseUrlApi()
method = unwrap(api.get)
@ -1789,7 +1790,7 @@ class TestDatasetApiBaseUrlApi:
class TestDatasetRetrievalSettingApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRetrievalSettingApi()
method = unwrap(api.get)
@ -1810,7 +1811,7 @@ class TestDatasetRetrievalSettingApi:
class TestDatasetRetrievalSettingMockApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetRetrievalSettingMockApi()
method = unwrap(api.get)
@ -1827,7 +1828,7 @@ class TestDatasetRetrievalSettingMockApi:
class TestDatasetErrorDocs:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetErrorDocs()
method = unwrap(api.get)
@ -1850,7 +1851,7 @@ class TestDatasetErrorDocs:
assert status == 200
assert response["total"] == 1
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetErrorDocs()
method = unwrap(api.get)
@ -1866,7 +1867,7 @@ class TestDatasetErrorDocs:
class TestDatasetPermissionUserListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetPermissionUserListApi()
method = unwrap(api.get)
@ -1897,7 +1898,7 @@ class TestDatasetPermissionUserListApi:
assert status == 200
assert response["data"] == users
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetPermissionUserListApi()
method = unwrap(api.get)
@ -1923,7 +1924,7 @@ class TestDatasetPermissionUserListApi:
class TestDatasetAutoDisableLogApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetAutoDisableLogApi()
method = unwrap(api.get)
@ -1946,7 +1947,7 @@ class TestDatasetAutoDisableLogApi:
assert status == 200
assert response == logs
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetAutoDisableLogApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -239,7 +240,7 @@ class TestDatasetDocumentListApi:
assert "documents" in response
def test_post_forbidden(self, app):
def test_post_forbidden(self, app: Flask):
api = DatasetDocumentListApi()
method = unwrap(api.post)
@ -395,7 +396,7 @@ class TestDocumentDownloadApi:
class TestDocumentProcessingApi:
def test_processing_forbidden_when_not_editor(self, app):
def test_processing_forbidden_when_not_editor(self, app: Flask):
api = DocumentProcessingApi()
method = unwrap(api.patch)
@ -1185,7 +1186,7 @@ class TestDocumentPermissionCases:
"preview": [],
}
def test_document_tenant_mismatch(self, app):
def test_document_tenant_mismatch(self, app: Flask):
api = DocumentApi()
method = unwrap(api.get)
@ -1253,7 +1254,7 @@ class TestDocumentPermissionCases:
assert status == 200
assert response["mode"] == "custom"
def test_process_rule_permission_denied(self, app):
def test_process_rule_permission_denied(self, app: Flask):
api = GetProcessRuleApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound
import services
@ -82,7 +83,7 @@ def test_get_segment_with_summary(monkeypatch):
class TestDatasetDocumentSegmentListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -132,7 +133,7 @@ class TestDatasetDocumentSegmentListApi:
assert status == 200
def test_get_dataset_not_found(self, app):
def test_get_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -150,7 +151,7 @@ class TestDatasetDocumentSegmentListApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_get_permission_denied(self, app):
def test_get_permission_denied(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -176,7 +177,7 @@ class TestDatasetDocumentSegmentListApi:
class TestDatasetDocumentSegmentApi:
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -221,7 +222,7 @@ class TestDatasetDocumentSegmentApi:
assert status == 200
assert response["result"] == "success"
def test_patch_document_indexing_in_progress(self, app):
def test_patch_document_indexing_in_progress(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -264,7 +265,7 @@ class TestDatasetDocumentSegmentApi:
with pytest.raises(InvalidActionError):
method(api, "ds-1", "doc-1", "disable")
def test_patch_llm_bad_request(self, app):
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -308,7 +309,7 @@ class TestDatasetDocumentSegmentApi:
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1", "enable")
def test_patch_provider_token_not_init(self, app):
def test_patch_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentApi()
method = unwrap(api.patch)
@ -354,7 +355,7 @@ class TestDatasetDocumentSegmentApi:
class TestDatasetDocumentSegmentAddApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -413,7 +414,7 @@ class TestDatasetDocumentSegmentAddApi:
assert status == 200
assert response["data"]["id"] == "seg-1"
def test_post_llm_bad_request(self, app):
def test_post_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -452,7 +453,7 @@ class TestDatasetDocumentSegmentAddApi:
with pytest.raises(ProviderNotInitializeError):
method(api, "ds-1", "doc-1")
def test_post_provider_token_not_init(self, app):
def test_post_provider_token_not_init(self, app: Flask):
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -493,7 +494,7 @@ class TestDatasetDocumentSegmentAddApi:
class TestDatasetDocumentSegmentUpdateApi:
def test_patch_success(self, app):
def test_patch_success(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
@ -551,7 +552,7 @@ class TestDatasetDocumentSegmentUpdateApi:
assert status == 200
assert "data" in response
def test_patch_llm_bad_request(self, app):
def test_patch_llm_bad_request(self, app: Flask):
api = DatasetDocumentSegmentUpdateApi()
method = unwrap(api.patch)
@ -596,7 +597,7 @@ class TestDatasetDocumentSegmentUpdateApi:
class TestDatasetDocumentSegmentBatchImportApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -638,7 +639,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
assert status == 200
assert response["job_status"] == "waiting"
def test_post_dataset_not_found(self, app):
def test_post_dataset_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -659,7 +660,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_document_not_found(self, app):
def test_post_document_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -684,7 +685,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_upload_file_not_found(self, app):
def test_post_upload_file_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -713,7 +714,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_post_invalid_file_type(self, app):
def test_post_invalid_file_type(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -745,7 +746,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
with pytest.raises(ValueError):
method(api, "ds-1", "doc-1")
def test_post_async_task_failure(self, app):
def test_post_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -783,7 +784,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
assert status == 500
assert "error" in response
def test_get_job_not_found_in_redis(self, app):
def test_get_job_not_found_in_redis(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)
@ -799,7 +800,7 @@ class TestDatasetDocumentSegmentBatchImportApi:
class TestChildChunkAddApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
@ -852,7 +853,7 @@ class TestChildChunkAddApi:
assert status == 200
assert response["data"]["id"] == "cc-1"
def test_post_child_chunk_indexing_error(self, app):
def test_post_child_chunk_indexing_error(self, app: Flask):
api = ChildChunkAddApi()
method = unwrap(api.post)
@ -897,7 +898,7 @@ class TestChildChunkAddApi:
class TestChildChunkUpdateApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
@ -941,7 +942,7 @@ class TestChildChunkUpdateApi:
assert status == 204
assert response["result"] == "success"
def test_delete_child_chunk_index_error(self, app):
def test_delete_child_chunk_index_error(self, app: Flask):
api = ChildChunkUpdateApi()
method = unwrap(api.delete)
@ -984,7 +985,7 @@ class TestChildChunkUpdateApi:
class TestSegmentListAdvancedCases:
def test_segment_list_with_keyword_filter(self, app):
def test_segment_list_with_keyword_filter(self, app: Flask):
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1035,7 +1036,7 @@ class TestSegmentListAdvancedCases:
assert status == 200
assert response["total"] == 1
def test_segment_list_permission_denied(self, app):
def test_segment_list_permission_denied(self, app: Flask):
"""Test segment list with permission denied"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1058,7 +1059,7 @@ class TestSegmentListAdvancedCases:
with pytest.raises(Forbidden):
method(api, "ds-1", "doc-1")
def test_segment_list_dataset_not_found(self, app):
def test_segment_list_dataset_not_found(self, app: Flask):
"""Test segment list with dataset not found"""
api = DatasetDocumentSegmentListApi()
method = unwrap(api.get)
@ -1079,7 +1080,7 @@ class TestSegmentListAdvancedCases:
class TestSegmentOperationCases:
def test_segment_add_with_provider_token_error(self, app):
def test_segment_add_with_provider_token_error(self, app: Flask):
"""Test segment add with provider token not initialized"""
api = DatasetDocumentSegmentAddApi()
method = unwrap(api.post)
@ -1117,7 +1118,7 @@ class TestSegmentOperationCases:
with pytest.raises(ProviderTokenNotInitError):
method(api, "ds-1", "doc-1")
def test_batch_import_with_document_not_found(self, app):
def test_batch_import_with_document_not_found(self, app: Flask):
"""Test batch import with document not found"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1146,7 +1147,7 @@ class TestSegmentOperationCases:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_batch_import_with_invalid_file(self, app):
def test_batch_import_with_invalid_file(self, app: Flask):
"""Test batch import with invalid file type"""
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1181,7 +1182,7 @@ class TestSegmentOperationCases:
with pytest.raises(NotFound):
method(api, "ds-1", "doc-1")
def test_batch_import_with_async_task_failure(self, app):
def test_batch_import_with_async_task_failure(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.post)
@ -1226,7 +1227,7 @@ class TestSegmentOperationCases:
assert status == 500
assert "error" in response
def test_batch_import_get_job_not_found(self, app):
def test_batch_import_get_job_not_found(self, app: Flask):
api = DatasetDocumentSegmentBatchImportApi()
method = unwrap(api.get)

View File

@ -57,7 +57,7 @@ def mock_auth(monkeypatch, current_user):
class TestExternalApiTemplateListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.get)
@ -78,7 +78,7 @@ class TestExternalApiTemplateListApi:
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
def test_post_forbidden(self, app, current_user):
def test_post_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
@ -93,7 +93,7 @@ class TestExternalApiTemplateListApi:
with pytest.raises(Forbidden):
method(api)
def test_post_duplicate_name(self, app):
def test_post_duplicate_name(self, app: Flask):
api = ExternalApiTemplateListApi()
method = unwrap(api.post)
@ -114,7 +114,7 @@ class TestExternalApiTemplateListApi:
class TestExternalApiTemplateApi:
def test_get_not_found(self, app):
def test_get_not_found(self, app: Flask):
api = ExternalApiTemplateApi()
method = unwrap(api.get)
@ -129,7 +129,7 @@ class TestExternalApiTemplateApi:
with pytest.raises(NotFound):
method(api, "api-id")
def test_delete_forbidden(self, app, current_user):
def test_delete_forbidden(self, app: Flask, current_user):
current_user.has_edit_permission = False
current_user.is_dataset_operator = False
@ -142,7 +142,7 @@ class TestExternalApiTemplateApi:
class TestExternalApiUseCheckApi:
def test_get_scopes_usage_check_to_current_tenant(self, app):
def test_get_scopes_usage_check_to_current_tenant(self, app: Flask):
api = ExternalApiUseCheckApi()
method = unwrap(api.get)
@ -162,7 +162,7 @@ class TestExternalApiUseCheckApi:
class TestExternalDatasetCreateApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
@ -206,7 +206,7 @@ class TestExternalDatasetCreateApi:
assert status == 201
def test_create_forbidden(self, app, current_user):
def test_create_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
api = ExternalDatasetCreateApi()
method = unwrap(api.post)
@ -226,7 +226,7 @@ class TestExternalDatasetCreateApi:
class TestExternalKnowledgeHitTestingApi:
def test_hit_testing_dataset_not_found(self, app):
def test_hit_testing_dataset_not_found(self, app: Flask):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
@ -241,7 +241,7 @@ class TestExternalKnowledgeHitTestingApi:
with pytest.raises(NotFound):
method(api, "dataset-id")
def test_hit_testing_success(self, app):
def test_hit_testing_success(self, app: Flask):
api = ExternalKnowledgeHitTestingApi()
method = unwrap(api.post)
@ -266,7 +266,7 @@ class TestExternalKnowledgeHitTestingApi:
class TestBedrockRetrievalApi:
def test_bedrock_retrieval(self, app):
def test_bedrock_retrieval(self, app: Flask):
api = BedrockRetrievalApi()
method = unwrap(api.post)

View File

@ -269,7 +269,7 @@ class TestDatasetMetadataApi:
class TestDatasetMetadataBuiltInFieldApi:
def test_get_built_in_fields(self, app):
def test_get_built_in_fields(self, app: Flask):
api = DatasetMetadataBuiltInFieldApi()
method = unwrap(api.get)

View File

@ -1,6 +1,8 @@
from datetime import datetime
from unittest.mock import MagicMock, patch
from flask import Flask
import controllers.console.explore.banner as banner_module
from models.enums import BannerStatus
@ -12,7 +14,7 @@ def unwrap(func):
class TestBannerApi:
def test_get_banners_with_requested_language(self, app):
def test_get_banners_with_requested_language(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)
@ -41,7 +43,7 @@ class TestBannerApi:
}
]
def test_get_banners_fallback_to_en_us(self, app):
def test_get_banners_fallback_to_en_us(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)
@ -76,7 +78,7 @@ class TestBannerApi:
}
]
def test_get_banners_default_language_en_us(self, app):
def test_get_banners_default_language_en_us(self, app: Flask):
api = banner_module.BannerApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import InternalServerError, NotFound
import controllers.console.explore.message as module
@ -54,7 +55,7 @@ def make_message():
class TestMessageListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -96,7 +97,7 @@ class TestMessageListApi:
with pytest.raises(NotChatAppError):
method(installed_app)
def test_conversation_not_exists(self, app):
def test_conversation_not_exists(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -118,7 +119,7 @@ class TestMessageListApi:
with pytest.raises(NotFound):
method(installed_app)
def test_first_message_not_exists(self, app):
def test_first_message_not_exists(self, app: Flask):
api = module.MessageListApi()
method = unwrap(api.get)
@ -142,7 +143,7 @@ class TestMessageListApi:
class TestMessageFeedbackApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
@ -161,7 +162,7 @@ class TestMessageFeedbackApi:
assert result["result"] == "success"
def test_message_not_exists(self, app):
def test_message_not_exists(self, app: Flask):
api = module.MessageFeedbackApi()
method = unwrap(api.post)
@ -182,7 +183,7 @@ class TestMessageFeedbackApi:
class TestMessageMoreLikeThisApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -221,7 +222,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(NotCompletionAppError):
method(installed_app, "mid")
def test_more_like_this_disabled(self, app):
def test_more_like_this_disabled(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -243,7 +244,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(AppMoreLikeThisDisabledError):
method(installed_app, "mid")
def test_message_not_exists_more_like_this(self, app):
def test_message_not_exists_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -265,7 +266,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(NotFound):
method(installed_app, "mid")
def test_provider_not_init_more_like_this(self, app):
def test_provider_not_init_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -287,7 +288,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderNotInitializeError):
method(installed_app, "mid")
def test_quota_exceeded_more_like_this(self, app):
def test_quota_exceeded_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -309,7 +310,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderQuotaExceededError):
method(installed_app, "mid")
def test_model_not_support_more_like_this(self, app):
def test_model_not_support_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -331,7 +332,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method(installed_app, "mid")
def test_invoke_error_more_like_this(self, app):
def test_invoke_error_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)
@ -353,7 +354,7 @@ class TestMessageMoreLikeThisApi:
with pytest.raises(CompletionRequestError):
method(installed_app, "mid")
def test_unexpected_error_more_like_this(self, app):
def test_unexpected_error_more_like_this(self, app: Flask):
api = module.MessageMoreLikeThisApi()
method = unwrap(api.get)

View File

@ -1,5 +1,7 @@
from unittest.mock import MagicMock, patch
from flask import Flask
import controllers.console.explore.recommended_app as module
from models.model import AppMode, IconType
@ -11,7 +13,7 @@ def unwrap(func):
class TestRecommendedAppListApi:
def test_get_with_language_param(self, app):
def test_get_with_language_param(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -31,7 +33,7 @@ class TestRecommendedAppListApi:
service_mock.assert_called_once_with("en-US")
assert result == result_data
def test_get_fallback_to_user_language(self, app):
def test_get_fallback_to_user_language(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -51,7 +53,7 @@ class TestRecommendedAppListApi:
service_mock.assert_called_once_with("fr-FR")
assert result == result_data
def test_get_fallback_to_default_language(self, app):
def test_get_fallback_to_default_language(self, app: Flask):
api = module.RecommendedAppListApi()
method = unwrap(api.get)
@ -73,7 +75,7 @@ class TestRecommendedAppListApi:
class TestRecommendedAppApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.RecommendedAppApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ from unittest.mock import MagicMock, PropertyMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
import controllers.console.explore.saved_message as module
@ -42,7 +43,7 @@ def payload_patch():
class TestSavedMessageListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = module.SavedMessageListApi()
method = unwrap(api.get)

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import controllers.console.explore.trial as module
@ -88,7 +89,7 @@ def valid_parameters():
class TestTrialAppWorkflowRunApi:
def test_not_workflow_app(self, app):
def test_not_workflow_app(self, app: Flask):
api = module.TrialAppWorkflowRunApi()
method = unwrap(api.post)
@ -224,7 +225,7 @@ class TestTrialAppWorkflowRunApi:
class TestTrialChatApi:
def test_not_chat_app(self, app):
def test_not_chat_app(self, app: Flask):
api = module.TrialChatApi()
method = unwrap(api.post)
@ -408,7 +409,7 @@ class TestTrialChatApi:
class TestTrialCompletionApi:
def test_not_completion_app(self, app):
def test_not_completion_app(self, app: Flask):
api = module.TrialCompletionApi()
method = unwrap(api.post)
@ -560,7 +561,7 @@ class TestTrialCompletionApi:
class TestTrialMessageSuggestedQuestionApi:
def test_not_chat_app(self, app):
def test_not_chat_app(self, app: Flask):
api = module.TrialMessageSuggestedQuestionApi()
method = unwrap(api.get)
@ -952,7 +953,7 @@ class TestTrialAppWorkflowTaskStopApi:
class TestTrialSitApi:
def test_no_site(self, app):
def test_no_site(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
app_model = MagicMock()
@ -963,7 +964,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_archived_tenant(self, app):
def test_archived_tenant(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)
@ -978,7 +979,7 @@ class TestTrialSitApi:
with pytest.raises(Forbidden):
method(api, app_model)
def test_success(self, app):
def test_success(self, app: Flask):
api = module.TrialSitApi()
method = unwrap(api.get)

View File

@ -73,7 +73,7 @@ def payload_patch():
class TestTagListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = TagListApi()
method = unwrap(api.get)
@ -124,7 +124,7 @@ class TestTagListApi:
assert result["name"] == "test-tag"
assert result["binding_count"] == "0"
def test_post_forbidden(self, app, readonly_user, payload_patch):
def test_post_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagListApi()
method = unwrap(api.post)
@ -170,7 +170,7 @@ class TestTagUpdateDeleteApi:
assert status == 200
assert result["binding_count"] == "3"
def test_patch_forbidden(self, app, readonly_user, payload_patch):
def test_patch_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagUpdateDeleteApi()
method = unwrap(api.patch)
@ -231,7 +231,7 @@ class TestTagBindingCollectionApi:
assert status == 200
assert result["result"] == "success"
def test_create_forbidden(self, app, readonly_user, payload_patch):
def test_create_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagBindingCollectionApi()
method = unwrap(api.post)
@ -275,7 +275,7 @@ class TestTagBindingRemoveApi:
assert status == 200
assert result["result"] == "success"
def test_remove_forbidden(self, app, readonly_user, payload_patch):
def test_remove_forbidden(self, app: Flask, readonly_user, payload_patch):
api = TagBindingRemoveApi()
method = unwrap(api.post)

View File

@ -82,7 +82,7 @@ def mock_file_service(mock_db):
class TestFileApiGet:
def test_get_upload_config(self, app):
def test_get_upload_config(self, app: Flask):
api = FileApi()
get_method = unwrap(api.get)
@ -290,7 +290,7 @@ class TestFilePreviewApi:
class TestFileSupportTypeApi:
def test_get_supported_types(self, app):
def test_get_supported_types(self, app: Flask):
api = FileSupportTypeApi()
get_method = unwrap(api.get)

View File

@ -58,7 +58,7 @@ class TestChangeEmailSend:
mock_get_change_data,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
@ -107,7 +107,7 @@ class TestChangeEmailSend:
mock_get_change_data,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq: a phase-1 token must not unlock the new-email send step."""
from controllers.console.auth.error import InvalidTokenError
@ -155,7 +155,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("user@example.com", "acc2")
@ -214,7 +214,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("old@example.com", "acc"), None)
@ -267,7 +267,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A token whose phase marker is a string but not a known transition must be rejected."""
from controllers.console.auth.error import InvalidTokenError
@ -316,7 +316,7 @@ class TestChangeEmailValidity:
mock_reset_rate,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A token minted without a phase marker (e.g. a hand-crafted token) must not validate."""
from controllers.console.auth.error import InvalidTokenError
@ -366,7 +366,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
@ -418,7 +418,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""GHSA-4q3w-q5mc-45rq PoC: phase-1 token must not be usable against /reset."""
from controllers.console.auth.error import InvalidTokenError
@ -471,7 +471,7 @@ class TestChangeEmailReset:
mock_send_notify,
mock_current_account,
mock_db,
app,
app: Flask,
):
"""A verified token for address A must not be replayed to change to address B."""
from controllers.console.auth.error import InvalidTokenError
@ -547,7 +547,7 @@ class TestAccountServiceSendChangeEmailEmail:
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.BillingService.update_account_deletion_feedback")
def test_should_normalize_feedback_email(self, mock_update, mock_db, app):
def test_should_normalize_feedback_email(self, mock_update, mock_db, app: Flask):
with app.test_request_context(
"/account/delete/feedback",
method="POST",
@ -563,7 +563,7 @@ class TestCheckEmailUnique:
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app):
def test_should_normalize_email(self, mock_is_freeze, mock_check_unique, mock_db, app: Flask):
mock_is_freeze.return_value = False
mock_check_unique.return_value = True

View File

@ -43,7 +43,7 @@ class TestMemberInviteEmailApi:
mock_current_account,
mock_invite_member,
mock_get_features,
app,
app: Flask,
):
mock_get_features.return_value = _build_feature_flags()
mock_invite_member.return_value = "token-abc"

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, PropertyMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.console import console_ns
@ -41,7 +42,7 @@ def unwrap(func):
class TestAccountInitApi:
def test_init_success(self, app):
def test_init_success(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
@ -64,7 +65,7 @@ class TestAccountInitApi:
assert resp["result"] == "success"
def test_init_already_initialized(self, app):
def test_init_already_initialized(self, app: Flask):
api = AccountInitApi()
method = unwrap(api.post)
@ -79,7 +80,7 @@ class TestAccountInitApi:
class TestAccountProfileApi:
def test_get_profile_success(self, app):
def test_get_profile_success(self, app: Flask):
api = AccountProfileApi()
method = unwrap(api.get)
@ -140,7 +141,7 @@ class TestAccountUpdateApis:
class TestAccountAvatarApiGet:
"""GET /account/avatar must not sign arbitrary upload_file IDs (IDOR)."""
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app):
def test_get_avatar_signed_url_when_upload_owned_by_current_account(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -172,7 +173,7 @@ class TestAccountAvatarApiGet:
assert result == {"avatar_url": "https://signed/example"}
sign_mock.assert_called_once_with(upload_file_id=file_id)
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app):
def test_get_avatar_not_found_when_upload_created_by_other_account_same_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -204,7 +205,7 @@ class TestAccountAvatarApiGet:
sign_mock.assert_not_called()
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app):
def test_get_avatar_not_found_when_upload_belongs_to_other_tenant(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -236,7 +237,7 @@ class TestAccountAvatarApiGet:
sign_mock.assert_not_called()
def test_get_avatar_https_pass_through_without_signing(self, app):
def test_get_avatar_https_pass_through_without_signing(self, app: Flask):
api = AccountAvatarApi()
method = unwrap(api.get)
@ -263,7 +264,7 @@ class TestAccountAvatarApiGet:
class TestAccountPasswordApi:
def test_password_success(self, app):
def test_password_success(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
@ -292,7 +293,7 @@ class TestAccountPasswordApi:
assert result["id"] == "u1"
def test_password_wrong_current(self, app):
def test_password_wrong_current(self, app: Flask):
api = AccountPasswordApi()
method = unwrap(api.post)
@ -317,7 +318,7 @@ class TestAccountPasswordApi:
class TestAccountIntegrateApi:
def test_get_integrates(self, app):
def test_get_integrates(self, app: Flask):
api = AccountIntegrateApi()
method = unwrap(api.get)
@ -336,7 +337,7 @@ class TestAccountIntegrateApi:
class TestAccountDeleteApi:
def test_delete_verify_success(self, app):
def test_delete_verify_success(self, app: Flask):
api = AccountDeleteVerifyApi()
method = unwrap(api.get)
@ -358,7 +359,7 @@ class TestAccountDeleteApi:
assert result["result"] == "success"
def test_delete_invalid_code(self, app):
def test_delete_invalid_code(self, app: Flask):
api = AccountDeleteApi()
method = unwrap(api.post)
@ -379,7 +380,7 @@ class TestAccountDeleteApi:
class TestChangeEmailApis:
def test_check_email_code_invalid(self, app):
def test_check_email_code_invalid(self, app: Flask):
api = ChangeEmailCheckApi()
method = unwrap(api.post)
@ -405,7 +406,7 @@ class TestChangeEmailApis:
with pytest.raises(EmailCodeError):
method(api)
def test_reset_email_already_used(self, app):
def test_reset_email_already_used(self, app: Flask):
api = ChangeEmailResetApi()
method = unwrap(api.post)
@ -427,7 +428,7 @@ class TestChangeEmailApis:
class TestCheckEmailUniqueApi:
def test_email_unique_success(self, app):
def test_email_unique_success(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)
@ -448,7 +449,7 @@ class TestCheckEmailUniqueApi:
assert result["result"] == "success"
def test_email_in_freeze(self, app):
def test_email_in_freeze(self, app: Flask):
api = CheckEmailUnique()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.error import AccountNotFound
from controllers.console.workspace.agent_providers import (
@ -16,7 +17,7 @@ def unwrap(func):
class TestAgentProviderListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -39,7 +40,7 @@ class TestAgentProviderListApi:
assert result == providers
def test_get_empty_list(self, app):
def test_get_empty_list(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -61,7 +62,7 @@ class TestAgentProviderListApi:
assert result == []
def test_get_account_not_found(self, app):
def test_get_account_not_found(self, app: Flask):
api = AgentProviderListApi()
method = unwrap(api.get)
@ -77,7 +78,7 @@ class TestAgentProviderListApi:
class TestAgentProviderApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
@ -101,7 +102,7 @@ class TestAgentProviderApi:
assert result == provider_data
def test_get_provider_not_found(self, app):
def test_get_provider_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)
@ -124,7 +125,7 @@ class TestAgentProviderApi:
assert result is None
def test_get_account_not_found(self, app):
def test_get_account_not_found(self, app: Flask):
api = AgentProviderApi()
method = unwrap(api.get)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console import console_ns
from controllers.console.workspace.endpoint import (
@ -39,7 +40,7 @@ def patch_current_account(user_and_tenant):
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointCollectionApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -57,7 +58,7 @@ class TestEndpointCollectionApi:
assert result["success"] is True
def test_create_permission_denied(self, app):
def test_create_permission_denied(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -77,7 +78,7 @@ class TestEndpointCollectionApi:
with pytest.raises(ValueError):
method(api)
def test_create_validation_error(self, app):
def test_create_validation_error(self, app: Flask):
api = EndpointCollectionApi()
method = unwrap(api.post)
@ -96,7 +97,7 @@ class TestEndpointCollectionApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointCreateApi:
def test_create_success(self, app):
def test_create_success(self, app: Flask):
api = DeprecatedEndpointCreateApi()
method = unwrap(api.post)
@ -117,7 +118,7 @@ class TestDeprecatedEndpointCreateApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListApi:
def test_list_success(self, app):
def test_list_success(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
@ -130,7 +131,7 @@ class TestEndpointListApi:
assert "endpoints" in result
assert len(result["endpoints"]) == 1
def test_list_invalid_query(self, app):
def test_list_invalid_query(self, app: Flask):
api = EndpointListApi()
method = unwrap(api.get)
@ -143,7 +144,7 @@ class TestEndpointListApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointListForSinglePluginApi:
def test_list_for_plugin_success(self, app):
def test_list_for_plugin_success(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
@ -158,7 +159,7 @@ class TestEndpointListForSinglePluginApi:
assert "endpoints" in result
def test_list_for_plugin_missing_param(self, app):
def test_list_for_plugin_missing_param(self, app: Flask):
api = EndpointListForSinglePluginApi()
method = unwrap(api.get)
@ -171,7 +172,7 @@ class TestEndpointListForSinglePluginApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointItemApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
@ -187,7 +188,7 @@ class TestEndpointItemApi:
assert result["success"] is True
mock_delete.assert_called_once_with(tenant_id="t1", user_id="u1", endpoint_id="e1")
def test_delete_service_failure(self, app):
def test_delete_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.delete)
@ -199,7 +200,7 @@ class TestEndpointItemApi:
assert result["success"] is False
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -226,7 +227,7 @@ class TestEndpointItemApi:
settings={"x": 1},
)
def test_update_validation_error(self, app):
def test_update_validation_error(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -238,7 +239,7 @@ class TestEndpointItemApi:
with pytest.raises(ValueError):
method(api, "e1")
def test_update_service_failure(self, app):
def test_update_service_failure(self, app: Flask):
api = EndpointItemApi()
method = unwrap(api.patch)
@ -258,7 +259,7 @@ class TestEndpointItemApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointDeleteApi:
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -272,7 +273,7 @@ class TestDeprecatedEndpointDeleteApi:
assert result["success"] is True
def test_delete_invalid_payload(self, app):
def test_delete_invalid_payload(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -282,7 +283,7 @@ class TestDeprecatedEndpointDeleteApi:
with pytest.raises(ValueError):
method(api)
def test_delete_service_failure(self, app):
def test_delete_service_failure(self, app: Flask):
api = DeprecatedEndpointDeleteApi()
method = unwrap(api.post)
@ -299,7 +300,7 @@ class TestDeprecatedEndpointDeleteApi:
@pytest.mark.usefixtures("patch_current_account")
class TestDeprecatedEndpointUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -317,7 +318,7 @@ class TestDeprecatedEndpointUpdateApi:
assert result["success"] is True
def test_update_validation_error(self, app):
def test_update_validation_error(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -329,7 +330,7 @@ class TestDeprecatedEndpointUpdateApi:
with pytest.raises(ValueError):
method(api)
def test_update_service_failure(self, app):
def test_update_service_failure(self, app: Flask):
api = DeprecatedEndpointUpdateApi()
method = unwrap(api.post)
@ -380,7 +381,7 @@ class TestEndpointRouteMetadata:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointEnableApi:
def test_enable_success(self, app):
def test_enable_success(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -394,7 +395,7 @@ class TestEndpointEnableApi:
assert result["success"] is True
def test_enable_invalid_payload(self, app):
def test_enable_invalid_payload(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -404,7 +405,7 @@ class TestEndpointEnableApi:
with pytest.raises(ValueError):
method(api)
def test_enable_service_failure(self, app):
def test_enable_service_failure(self, app: Flask):
api = EndpointEnableApi()
method = unwrap(api.post)
@ -421,7 +422,7 @@ class TestEndpointEnableApi:
@pytest.mark.usefixtures("patch_current_account")
class TestEndpointDisableApi:
def test_disable_success(self, app):
def test_disable_success(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)
@ -435,7 +436,7 @@ class TestEndpointDisableApi:
assert result["success"] is True
def test_disable_invalid_payload(self, app):
def test_disable_invalid_payload(self, app: Flask):
api = EndpointDisableApi()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import HTTPException
import services
@ -34,7 +35,7 @@ def unwrap(func):
class TestMemberListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = MemberListApi()
method = unwrap(api.get)
@ -59,7 +60,7 @@ class TestMemberListApi:
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
def test_get_no_tenant(self, app: Flask):
api = MemberListApi()
method = unwrap(api.get)
@ -74,7 +75,7 @@ class TestMemberListApi:
class TestMemberInviteEmailApi:
def test_invite_success(self, app):
def test_invite_success(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -101,7 +102,7 @@ class TestMemberInviteEmailApi:
assert status == 201
assert result["result"] == "success"
def test_invite_limit_exceeded(self, app):
def test_invite_limit_exceeded(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -123,7 +124,7 @@ class TestMemberInviteEmailApi:
with pytest.raises(WorkspaceMembersLimitExceeded):
method(api)
def test_invite_already_member(self, app):
def test_invite_already_member(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -151,7 +152,7 @@ class TestMemberInviteEmailApi:
assert result["invitation_results"][0]["status"] == "success"
def test_invite_invalid_role(self, app):
def test_invite_invalid_role(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -166,7 +167,7 @@ class TestMemberInviteEmailApi:
assert status == 400
assert result["code"] == "invalid-role"
def test_invite_generic_exception(self, app):
def test_invite_generic_exception(self, app: Flask):
api = MemberInviteEmailApi()
method = unwrap(api.post)
@ -196,7 +197,7 @@ class TestMemberInviteEmailApi:
class TestMemberCancelInviteApi:
def test_cancel_success(self, app):
def test_cancel_success(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -216,7 +217,7 @@ class TestMemberCancelInviteApi:
assert status == 200
assert result["result"] == "success"
def test_cancel_not_found(self, app):
def test_cancel_not_found(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -233,7 +234,7 @@ class TestMemberCancelInviteApi:
with pytest.raises(HTTPException):
method(api, "x")
def test_cancel_cannot_operate_self(self, app):
def test_cancel_cannot_operate_self(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -255,7 +256,7 @@ class TestMemberCancelInviteApi:
assert status == 400
def test_cancel_no_permission(self, app):
def test_cancel_no_permission(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -277,7 +278,7 @@ class TestMemberCancelInviteApi:
assert status == 403
def test_cancel_member_not_in_tenant(self, app):
def test_cancel_member_not_in_tenant(self, app: Flask):
api = MemberCancelInviteApi()
method = unwrap(api.delete)
@ -301,7 +302,7 @@ class TestMemberCancelInviteApi:
class TestMemberUpdateRoleApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -324,7 +325,7 @@ class TestMemberUpdateRoleApi:
assert result["result"] == "success"
def test_update_invalid_role(self, app):
def test_update_invalid_role(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -335,7 +336,7 @@ class TestMemberUpdateRoleApi:
assert status == 400
def test_update_member_not_found(self, app):
def test_update_member_not_found(self, app: Flask):
api = MemberUpdateRoleApi()
method = unwrap(api.put)
@ -354,7 +355,7 @@ class TestMemberUpdateRoleApi:
class TestDatasetOperatorMemberListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
@ -381,7 +382,7 @@ class TestDatasetOperatorMemberListApi:
assert status == 200
assert len(result["accounts"]) == 1
def test_get_no_tenant(self, app):
def test_get_no_tenant(self, app: Flask):
api = DatasetOperatorMemberListApi()
method = unwrap(api.get)
@ -396,7 +397,7 @@ class TestDatasetOperatorMemberListApi:
class TestSendOwnerTransferEmailApi:
def test_send_success(self, app):
def test_send_success(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -419,7 +420,7 @@ class TestSendOwnerTransferEmailApi:
assert result["result"] == "success"
def test_send_ip_limit(self, app):
def test_send_ip_limit(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -433,7 +434,7 @@ class TestSendOwnerTransferEmailApi:
with pytest.raises(EmailSendIpLimitError):
method(api)
def test_send_not_owner(self, app):
def test_send_not_owner(self, app: Flask):
api = SendOwnerTransferEmailApi()
method = unwrap(api.post)
@ -452,7 +453,7 @@ class TestSendOwnerTransferEmailApi:
class TestOwnerTransferCheckApi:
def test_check_invalid_code(self, app):
def test_check_invalid_code(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -477,7 +478,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(EmailCodeError):
method(api)
def test_rate_limited(self, app):
def test_rate_limited(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -498,7 +499,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(OwnerTransferLimitError):
method(api)
def test_invalid_token(self, app):
def test_invalid_token(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -520,7 +521,7 @@ class TestOwnerTransferCheckApi:
with pytest.raises(InvalidTokenError):
method(api)
def test_invalid_email(self, app):
def test_invalid_email(self, app: Flask):
api = OwnerTransferCheckApi()
method = unwrap(api.post)
@ -547,7 +548,7 @@ class TestOwnerTransferCheckApi:
class TestOwnerTransferApi:
def test_transfer_self(self, app):
def test_transfer_self(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
@ -564,7 +565,7 @@ class TestOwnerTransferApi:
with pytest.raises(CannotTransferOwnerToSelfError):
method(api, "1")
def test_invalid_token(self, app):
def test_invalid_token(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)
@ -582,7 +583,7 @@ class TestOwnerTransferApi:
with pytest.raises(InvalidTokenError):
method(api, "2")
def test_member_not_in_tenant(self, app):
def test_member_not_in_tenant(self, app: Flask):
api = OwnerTransfer()
method = unwrap(api.post)

View File

@ -1,6 +1,7 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from pydantic_core import ValidationError
from werkzeug.exceptions import Forbidden
@ -26,7 +27,7 @@ def unwrap(func):
class TestModelProviderListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ModelProviderListApi()
method = unwrap(api.get)
@ -47,7 +48,7 @@ class TestModelProviderListApi:
class TestModelProviderCredentialApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
@ -66,7 +67,7 @@ class TestModelProviderCredentialApi:
assert "credentials" in result
def test_get_invalid_uuid(self, app):
def test_get_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.get)
@ -80,7 +81,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_post_create_success(self, app):
def test_post_create_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
@ -102,7 +103,7 @@ class TestModelProviderCredentialApi:
assert result["result"] == "success"
assert status == 201
def test_post_create_validation_error(self, app):
def test_post_create_validation_error(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.post)
@ -122,7 +123,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValueError):
method(api, provider="openai")
def test_put_update_success(self, app):
def test_put_update_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
@ -143,7 +144,7 @@ class TestModelProviderCredentialApi:
assert result["result"] == "success"
def test_put_invalid_uuid(self, app):
def test_put_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.put)
@ -159,7 +160,7 @@ class TestModelProviderCredentialApi:
with pytest.raises(ValidationError):
method(api, provider="openai")
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ModelProviderCredentialApi()
method = unwrap(api.delete)
@ -183,7 +184,7 @@ class TestModelProviderCredentialApi:
class TestModelProviderCredentialSwitchApi:
def test_switch_success(self, app):
def test_switch_success(self, app: Flask):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
@ -204,7 +205,7 @@ class TestModelProviderCredentialSwitchApi:
assert result["result"] == "success"
def test_switch_invalid_uuid(self, app):
def test_switch_invalid_uuid(self, app: Flask):
api = ModelProviderCredentialSwitchApi()
method = unwrap(api.post)
@ -222,7 +223,7 @@ class TestModelProviderCredentialSwitchApi:
class TestModelProviderValidateApi:
def test_validate_success(self, app):
def test_validate_success(self, app: Flask):
api = ModelProviderValidateApi()
method = unwrap(api.post)
@ -243,7 +244,7 @@ class TestModelProviderValidateApi:
assert result["result"] == "success"
def test_validate_failure(self, app):
def test_validate_failure(self, app: Flask):
api = ModelProviderValidateApi()
method = unwrap(api.post)
@ -266,7 +267,7 @@ class TestModelProviderValidateApi:
class TestModelProviderIconApi:
def test_icon_success(self, app):
def test_icon_success(self, app: Flask):
api = ModelProviderIconApi()
with (
@ -280,7 +281,7 @@ class TestModelProviderIconApi:
assert response.mimetype == "image/png"
def test_icon_not_found(self, app):
def test_icon_not_found(self, app: Flask):
api = ModelProviderIconApi()
with (
@ -295,7 +296,7 @@ class TestModelProviderIconApi:
class TestPreferredProviderTypeUpdateApi:
def test_update_success(self, app):
def test_update_success(self, app: Flask):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
@ -316,7 +317,7 @@ class TestPreferredProviderTypeUpdateApi:
assert result["result"] == "success"
def test_invalid_enum(self, app):
def test_invalid_enum(self, app: Flask):
api = PreferredProviderTypeUpdateApi()
method = unwrap(api.post)
@ -334,7 +335,7 @@ class TestPreferredProviderTypeUpdateApi:
class TestModelProviderPaymentCheckoutUrlApi:
def test_checkout_success(self, app):
def test_checkout_success(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
@ -359,7 +360,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
assert "url" in result
def test_invalid_provider(self, app):
def test_invalid_provider(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)
@ -367,7 +368,7 @@ class TestModelProviderPaymentCheckoutUrlApi:
with pytest.raises(ValueError):
method(api, provider="openai")
def test_permission_denied(self, app):
def test_permission_denied(self, app: Flask):
api = ModelProviderPaymentCheckoutUrlApi()
method = unwrap(api.get)

View File

@ -72,7 +72,7 @@ class TestDefaultModelApi:
assert result["result"] == "success"
def test_get_returns_empty_when_no_default(self, app):
def test_get_returns_empty_when_no_default(self, app: Flask):
api = DefaultModelApi()
method = unwrap(api.get)
@ -154,7 +154,7 @@ class TestModelProviderModelApi:
assert status == 204
def test_get_models_returns_empty(self, app):
def test_get_models_returns_empty(self, app: Flask):
api = ModelProviderModelApi()
method = unwrap(api.get)
@ -224,7 +224,7 @@ class TestModelProviderModelCredentialApi:
assert status == 201
def test_get_empty_credentials(self, app):
def test_get_empty_credentials(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.get)
@ -242,7 +242,7 @@ class TestModelProviderModelCredentialApi:
assert result["credentials"] == {}
def test_delete_success(self, app):
def test_delete_success(self, app: Flask):
api = ModelProviderModelCredentialApi()
method = unwrap(api.delete)
@ -416,7 +416,7 @@ class TestParameterAndAvailableModels:
assert "data" in result
def test_empty_rules(self, app):
def test_empty_rules(self, app: Flask):
api = ModelProviderModelParameterRuleApi()
method = unwrap(api.get)
@ -431,7 +431,7 @@ class TestParameterAndAvailableModels:
assert result["data"] == []
def test_no_models(self, app):
def test_no_models(self, app: Flask):
api = ModelProviderAvailableModelApi()
method = unwrap(api.get)

View File

@ -2,6 +2,7 @@ import io
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden
@ -61,7 +62,7 @@ def tenant():
class TestPluginListLatestVersionsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginListLatestVersionsApi()
method = unwrap(api.post)
@ -77,7 +78,7 @@ class TestPluginListLatestVersionsApi:
assert "versions" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginListLatestVersionsApi()
method = unwrap(api.post)
@ -95,7 +96,7 @@ class TestPluginListLatestVersionsApi:
class TestPluginDebuggingKeyApi:
def test_debugging_key_success(self, app):
def test_debugging_key_success(self, app: Flask):
api = PluginDebuggingKeyApi()
method = unwrap(api.get)
@ -108,7 +109,7 @@ class TestPluginDebuggingKeyApi:
assert result["key"] == "k"
def test_debugging_key_error(self, app):
def test_debugging_key_error(self, app: Flask):
api = PluginDebuggingKeyApi()
method = unwrap(api.get)
@ -125,7 +126,7 @@ class TestPluginDebuggingKeyApi:
class TestPluginListApi:
def test_plugin_list(self, app):
def test_plugin_list(self, app: Flask):
api = PluginListApi()
method = unwrap(api.get)
@ -142,7 +143,7 @@ class TestPluginListApi:
class TestPluginIconApi:
def test_plugin_icon(self, app):
def test_plugin_icon(self, app: Flask):
api = PluginIconApi()
method = unwrap(api.get)
@ -156,7 +157,7 @@ class TestPluginIconApi:
class TestPluginAssetApi:
def test_plugin_asset(self, app):
def test_plugin_asset(self, app: Flask):
api = PluginAssetApi()
method = unwrap(api.get)
@ -171,7 +172,7 @@ class TestPluginAssetApi:
class TestPluginUploadFromPkgApi:
def test_upload_pkg_success(self, app):
def test_upload_pkg_success(self, app: Flask):
api = PluginUploadFromPkgApi()
method = unwrap(api.post)
@ -188,7 +189,7 @@ class TestPluginUploadFromPkgApi:
assert result["ok"] is True
def test_upload_pkg_too_large(self, app):
def test_upload_pkg_too_large(self, app: Flask):
api = PluginUploadFromPkgApi()
method = unwrap(api.post)
@ -210,7 +211,7 @@ class TestPluginUploadFromPkgApi:
class TestPluginInstallFromPkgApi:
def test_install_from_pkg(self, app):
def test_install_from_pkg(self, app: Flask):
api = PluginInstallFromPkgApi()
method = unwrap(api.post)
@ -229,7 +230,7 @@ class TestPluginInstallFromPkgApi:
class TestPluginUninstallApi:
def test_uninstall(self, app):
def test_uninstall(self, app: Flask):
api = PluginUninstallApi()
method = unwrap(api.post)
@ -246,7 +247,7 @@ class TestPluginUninstallApi:
class TestPluginChangePermissionApi:
def test_change_permission_forbidden(self, app):
def test_change_permission_forbidden(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
@ -264,7 +265,7 @@ class TestPluginChangePermissionApi:
with pytest.raises(Forbidden):
method(api)
def test_change_permission_success(self, app):
def test_change_permission_success(self, app: Flask):
api = PluginChangePermissionApi()
method = unwrap(api.post)
@ -286,7 +287,7 @@ class TestPluginChangePermissionApi:
class TestPluginFetchPermissionApi:
def test_fetch_permission_default(self, app):
def test_fetch_permission_default(self, app: Flask):
api = PluginFetchPermissionApi()
method = unwrap(api.get)
@ -319,7 +320,7 @@ class TestPluginFetchDynamicSelectOptionsApi:
class TestPluginReadmeApi:
def test_fetch_readme(self, app):
def test_fetch_readme(self, app: Flask):
api = PluginReadmeApi()
method = unwrap(api.get)
@ -334,7 +335,7 @@ class TestPluginReadmeApi:
class TestPluginListInstallationsFromIdsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -352,7 +353,7 @@ class TestPluginListInstallationsFromIdsApi:
assert "plugins" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginListInstallationsFromIdsApi()
method = unwrap(api.post)
@ -371,7 +372,7 @@ class TestPluginListInstallationsFromIdsApi:
class TestPluginUploadFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -388,7 +389,7 @@ class TestPluginUploadFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUploadFromGithubApi()
method = unwrap(api.post)
@ -407,7 +408,7 @@ class TestPluginUploadFromGithubApi:
class TestPluginUploadFromBundleApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUploadFromBundleApi()
method = unwrap(api.post)
@ -430,7 +431,7 @@ class TestPluginUploadFromBundleApi:
assert result["ok"] is True
def test_too_large(self, app):
def test_too_large(self, app: Flask):
api = PluginUploadFromBundleApi()
method = unwrap(api.post)
@ -458,7 +459,7 @@ class TestPluginUploadFromBundleApi:
class TestPluginInstallFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginInstallFromGithubApi()
method = unwrap(api.post)
@ -478,7 +479,7 @@ class TestPluginInstallFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginInstallFromGithubApi()
method = unwrap(api.post)
@ -502,7 +503,7 @@ class TestPluginInstallFromGithubApi:
class TestPluginInstallFromMarketplaceApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginInstallFromMarketplaceApi()
method = unwrap(api.post)
@ -520,7 +521,7 @@ class TestPluginInstallFromMarketplaceApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginInstallFromMarketplaceApi()
method = unwrap(api.post)
@ -539,7 +540,7 @@ class TestPluginInstallFromMarketplaceApi:
class TestPluginFetchMarketplacePkgApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchMarketplacePkgApi()
method = unwrap(api.get)
@ -552,7 +553,7 @@ class TestPluginFetchMarketplacePkgApi:
assert "manifest" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchMarketplacePkgApi()
method = unwrap(api.get)
@ -569,7 +570,7 @@ class TestPluginFetchMarketplacePkgApi:
class TestPluginFetchManifestApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchManifestApi()
method = unwrap(api.get)
@ -585,7 +586,7 @@ class TestPluginFetchManifestApi:
assert "manifest" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchManifestApi()
method = unwrap(api.get)
@ -602,7 +603,7 @@ class TestPluginFetchManifestApi:
class TestPluginFetchInstallTasksApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchInstallTasksApi()
method = unwrap(api.get)
@ -615,7 +616,7 @@ class TestPluginFetchInstallTasksApi:
assert "tasks" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchInstallTasksApi()
method = unwrap(api.get)
@ -632,7 +633,7 @@ class TestPluginFetchInstallTasksApi:
class TestPluginFetchInstallTaskApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchInstallTaskApi()
method = unwrap(api.get)
@ -645,7 +646,7 @@ class TestPluginFetchInstallTaskApi:
assert "task" in result
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchInstallTaskApi()
method = unwrap(api.get)
@ -662,7 +663,7 @@ class TestPluginFetchInstallTaskApi:
class TestPluginDeleteInstallTaskApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteInstallTaskApi()
method = unwrap(api.post)
@ -675,7 +676,7 @@ class TestPluginDeleteInstallTaskApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteInstallTaskApi()
method = unwrap(api.post)
@ -692,7 +693,7 @@ class TestPluginDeleteInstallTaskApi:
class TestPluginDeleteAllInstallTaskItemsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteAllInstallTaskItemsApi()
method = unwrap(api.post)
@ -707,7 +708,7 @@ class TestPluginDeleteAllInstallTaskItemsApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteAllInstallTaskItemsApi()
method = unwrap(api.post)
@ -724,7 +725,7 @@ class TestPluginDeleteAllInstallTaskItemsApi:
class TestPluginDeleteInstallTaskItemApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginDeleteInstallTaskItemApi()
method = unwrap(api.post)
@ -737,7 +738,7 @@ class TestPluginDeleteInstallTaskItemApi:
assert result["success"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginDeleteInstallTaskItemApi()
method = unwrap(api.post)
@ -754,7 +755,7 @@ class TestPluginDeleteInstallTaskItemApi:
class TestPluginUpgradeFromMarketplaceApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -775,7 +776,7 @@ class TestPluginUpgradeFromMarketplaceApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUpgradeFromMarketplaceApi()
method = unwrap(api.post)
@ -797,7 +798,7 @@ class TestPluginUpgradeFromMarketplaceApi:
class TestPluginUpgradeFromGithubApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginUpgradeFromGithubApi()
method = unwrap(api.post)
@ -821,7 +822,7 @@ class TestPluginUpgradeFromGithubApi:
assert result["ok"] is True
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginUpgradeFromGithubApi()
method = unwrap(api.post)
@ -846,7 +847,7 @@ class TestPluginUpgradeFromGithubApi:
class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
@ -873,7 +874,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
assert result["options"] == [1]
def test_daemon_error(self, app):
def test_daemon_error(self, app: Flask):
api = PluginFetchDynamicSelectOptionsWithCredentialsApi()
method = unwrap(api.post)
@ -901,7 +902,7 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi:
class TestPluginChangePreferencesApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginChangePreferencesApi()
method = unwrap(api.post)
@ -931,7 +932,7 @@ class TestPluginChangePreferencesApi:
assert result["success"] is True
def test_permission_fail(self, app):
def test_permission_fail(self, app: Flask):
api = PluginChangePreferencesApi()
method = unwrap(api.post)
@ -962,7 +963,7 @@ class TestPluginChangePreferencesApi:
class TestPluginFetchPreferencesApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginFetchPreferencesApi()
method = unwrap(api.get)
@ -996,7 +997,7 @@ class TestPluginFetchPreferencesApi:
class TestPluginAutoUpgradeExcludePluginApi:
def test_success(self, app):
def test_success(self, app: Flask):
api = PluginAutoUpgradeExcludePluginApi()
method = unwrap(api.post)
@ -1011,7 +1012,7 @@ class TestPluginAutoUpgradeExcludePluginApi:
assert result["success"] is True
def test_fail(self, app):
def test_fail(self, app: Flask):
api = PluginAutoUpgradeExcludePluginApi()
method = unwrap(api.post)

View File

@ -2,6 +2,7 @@ from io import BytesIO
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Unauthorized
@ -37,7 +38,7 @@ def unwrap(func):
class TestTenantListApi:
def test_get_success_saas_path(self, app):
def test_get_success_saas_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -85,7 +86,7 @@ class TestTenantListApi:
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app):
def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app: Flask):
"""Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used.
billing.enabled is mocked False to prove the endpoint does not gate on it for this path
@ -140,7 +141,7 @@ class TestTenantListApi:
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_called_once_with("t2")
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app: Flask):
"""Test fallback to FeatureService when bulk billing returns empty result.
BillingService.get_plan_bulk catches exceptions internally and returns empty dict,
@ -197,7 +198,7 @@ class TestTenantListApi:
assert get_features_mock.call_count == 2
logger_warning_mock.assert_called_once()
def test_get_billing_disabled_community_path(self, app):
def test_get_billing_disabled_community_path(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -236,7 +237,7 @@ class TestTenantListApi:
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
def test_get_enterprise_only_skips_feature_service(self, app):
def test_get_enterprise_only_skips_feature_service(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -276,7 +277,7 @@ class TestTenantListApi:
assert result["workspaces"][1]["current"] is True
get_features_mock.assert_not_called()
def test_get_enterprise_only_with_empty_tenants(self, app):
def test_get_enterprise_only_with_empty_tenants(self, app: Flask):
api = TenantListApi()
method = unwrap(api.get)
@ -302,7 +303,7 @@ class TestTenantListApi:
class TestWorkspaceListApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
@ -324,7 +325,7 @@ class TestWorkspaceListApi:
assert result["total"] == 1
assert result["has_more"] is False
def test_get_has_next_true(self, app):
def test_get_has_next_true(self, app: Flask):
api = WorkspaceListApi()
method = unwrap(api.get)
@ -355,7 +356,7 @@ class TestWorkspaceListApi:
class TestTenantApi:
def test_post_active_tenant(self, app):
def test_post_active_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -375,7 +376,7 @@ class TestTenantApi:
assert status == 200
assert result["id"] == "t1"
def test_post_archived_with_switch(self, app):
def test_post_archived_with_switch(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -397,7 +398,7 @@ class TestTenantApi:
assert result["id"] == "new"
def test_post_archived_no_tenant(self, app):
def test_post_archived_no_tenant(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -411,7 +412,7 @@ class TestTenantApi:
with pytest.raises(Unauthorized):
method(api)
def test_post_info_path(self, app):
def test_post_info_path(self, app: Flask):
api = TenantApi()
method = unwrap(api.post)
@ -454,7 +455,7 @@ class TestTenantInfoResponse:
class TestSwitchWorkspaceApi:
def test_switch_success(self, app):
def test_switch_success(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -477,7 +478,7 @@ class TestSwitchWorkspaceApi:
assert result["result"] == "success"
def test_switch_not_linked(self, app):
def test_switch_not_linked(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -493,7 +494,7 @@ class TestSwitchWorkspaceApi:
with pytest.raises(AccountNotLinkTenantError):
method(api)
def test_switch_tenant_not_found(self, app):
def test_switch_tenant_not_found(self, app: Flask):
api = SwitchWorkspaceApi()
method = unwrap(api.post)
@ -515,7 +516,7 @@ class TestSwitchWorkspaceApi:
class TestCustomConfigWorkspaceApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
@ -538,7 +539,7 @@ class TestCustomConfigWorkspaceApi:
assert result["result"] == "success"
def test_logo_fallback(self, app):
def test_logo_fallback(self, app: Flask):
api = CustomConfigWorkspaceApi()
method = unwrap(api.post)
@ -569,7 +570,7 @@ class TestCustomConfigWorkspaceApi:
class TestWebappLogoWorkspaceApi:
def test_no_file(self, app):
def test_no_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -582,7 +583,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(NoFileUploadedError):
method(api)
def test_too_many_files(self, app):
def test_too_many_files(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -601,7 +602,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(TooManyFilesError):
method(api)
def test_invalid_extension(self, app):
def test_invalid_extension(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -616,7 +617,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(UnsupportedFileTypeError):
method(api)
def test_upload_success(self, app):
def test_upload_success(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -648,7 +649,7 @@ class TestWebappLogoWorkspaceApi:
assert status == 201
assert result["id"] == "file1"
def test_filename_missing(self, app):
def test_filename_missing(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -672,7 +673,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(FilenameNotExistsError):
method(api)
def test_file_too_large(self, app):
def test_file_too_large(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -701,7 +702,7 @@ class TestWebappLogoWorkspaceApi:
with pytest.raises(FileTooLargeError):
method(api)
def test_service_unsupported_file(self, app):
def test_service_unsupported_file(self, app: Flask):
api = WebappLogoWorkspaceApi()
method = unwrap(api.post)
@ -732,7 +733,7 @@ class TestWebappLogoWorkspaceApi:
class TestWorkspaceInfoApi:
def test_post_success(self, app):
def test_post_success(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
@ -756,7 +757,7 @@ class TestWorkspaceInfoApi:
assert result["result"] == "success"
def test_no_current_tenant(self, app):
def test_no_current_tenant(self, app: Flask):
api = WorkspaceInfoApi()
method = unwrap(api.post)
@ -774,7 +775,7 @@ class TestWorkspaceInfoApi:
class TestWorkspacePermissionApi:
def test_get_success(self, app):
def test_get_success(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)
@ -799,7 +800,7 @@ class TestWorkspacePermissionApi:
assert status == 200
assert result["workspace_id"] == "t1"
def test_no_current_tenant(self, app):
def test_no_current_tenant(self, app: Flask):
api = WorkspacePermissionApi()
method = unwrap(api.get)

View File

@ -41,7 +41,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_chat_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving parameters for a chat app."""
# Arrange
@ -91,7 +91,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_for_workflow_app(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving parameters for a workflow app."""
# Arrange
@ -136,7 +136,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_chat_config_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test that AppUnavailableError is raised when chat app has no config."""
# Arrange
@ -174,7 +174,7 @@ class TestAppParameterApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_parameters_raises_error_when_workflow_missing(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test that AppUnavailableError is raised when workflow app has no workflow."""
# Arrange
@ -234,7 +234,14 @@ class TestAppMetaApi:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.app.app.AppService")
def test_get_app_meta(
self, mock_app_service, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self,
mock_app_service,
mock_db,
mock_validate_token,
mock_current_app,
mock_user_logged_in,
app: Flask,
mock_app_model,
):
"""Test retrieving app metadata via AppService."""
# Arrange
@ -310,7 +317,7 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, mock_app_model
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, mock_app_model
):
"""Test retrieving basic app information."""
mock_current_app.login_manager = Mock()
@ -402,7 +409,9 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.current_app")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_with_no_tags(self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app):
def test_get_app_info_with_no_tags(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask
):
"""Test retrieving app info when app has no tags."""
# Arrange
mock_current_app.login_manager = Mock()
@ -453,7 +462,7 @@ class TestAppInfoApi:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.db")
def test_get_app_info_returns_correct_mode(
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app, app_mode
self, mock_db, mock_validate_token, mock_current_app, mock_user_logged_in, app: Flask, app_mode
):
"""Test that all app modes are correctly returned."""
# Arrange

View File

@ -13,6 +13,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
@ -190,7 +191,7 @@ class TestAudioServiceMockedBehavior:
class TestAudioApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = AudioApi()
handler = _unwrap(api.post)
@ -216,7 +217,7 @@ class TestAudioApi:
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = AudioApi()
handler = _unwrap(api.post)
@ -227,7 +228,7 @@ class TestAudioApi:
with pytest.raises(expected):
handler(api, app_model=app_model, end_user=end_user)
def test_unhandled_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_unhandled_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
)
@ -242,7 +243,7 @@ class TestAudioApi:
class TestTextApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = TextApi()
@ -259,7 +260,7 @@ class TestTextApi:
assert response == {"audio": "ok"}
def test_error_mapping(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())
)

View File

@ -16,6 +16,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import BadRequest, NotFound
@ -295,7 +296,7 @@ class TestCompletionControllerLogic:
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
def test_completion_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask):
"""Test CompletionApi.post success path."""
from controllers.service_api.app.completion import CompletionApi
@ -320,7 +321,7 @@ class TestCompletionControllerLogic:
mock_generate_service.generate.assert_called_once()
@patch("controllers.service_api.app.completion.service_api_ns")
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app):
def test_completion_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask):
"""Test CompletionApi.post with wrong app mode."""
from controllers.service_api.app.completion import CompletionApi
@ -334,7 +335,7 @@ class TestCompletionControllerLogic:
@patch("controllers.service_api.app.completion.service_api_ns")
@patch("controllers.service_api.app.completion.AppGenerateService")
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app):
def test_chat_api_post_success(self, mock_generate_service, mock_service_api_ns, app: Flask):
"""Test ChatApi.post success path."""
from controllers.service_api.app.completion import ChatApi
@ -355,7 +356,7 @@ class TestCompletionControllerLogic:
assert response == {"text": "compacted"}
@patch("controllers.service_api.app.completion.service_api_ns")
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app):
def test_chat_api_post_wrong_app_mode(self, mock_service_api_ns, app: Flask):
"""Test ChatApi.post with wrong app mode."""
from controllers.service_api.app.completion import ChatApi
@ -368,7 +369,7 @@ class TestCompletionControllerLogic:
ChatApi().post.__wrapped__(ChatApi(), mock_app_model, mock_end_user)
@patch("controllers.service_api.app.completion.AppTaskService")
def test_completion_stop_api_success(self, mock_task_service, app):
def test_completion_stop_api_success(self, mock_task_service, app: Flask):
"""Test CompletionStopApi.post success."""
from controllers.service_api.app.completion import CompletionStopApi
@ -385,7 +386,7 @@ class TestCompletionControllerLogic:
mock_task_service.stop_task.assert_called_once()
@patch("controllers.service_api.app.completion.AppTaskService")
def test_chat_stop_api_success(self, mock_task_service, app):
def test_chat_stop_api_success(self, mock_task_service, app: Flask):
"""Test ChatStopApi.post success."""
from controllers.service_api.app.completion import ChatStopApi

View File

@ -20,6 +20,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
import services
@ -504,7 +505,7 @@ class TestConversationApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_list_last_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_list_last_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _BeginStub:
def __enter__(self):
return SimpleNamespace()
@ -552,7 +553,7 @@ class TestConversationDetailApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_delete_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"delete",
@ -570,7 +571,7 @@ class TestConversationDetailApiController:
class TestConversationRenameApiController:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"rename",
@ -602,7 +603,7 @@ class TestConversationVariablesApiController:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"get_conversational_variable",
@ -621,7 +622,7 @@ class TestConversationVariablesApiController:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, c_id="00000000-0000-0000-0000-000000000001")
def test_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,
@ -661,7 +662,7 @@ class TestConversationVariablesApiController:
class TestConversationVariableDetailApiController:
def test_update_type_mismatch(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_type_mismatch(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
@ -687,7 +688,7 @@ class TestConversationVariableDetailApiController:
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
ConversationService,
"update_conversation_variable",
@ -713,7 +714,7 @@ class TestConversationVariableDetailApiController:
variable_id="00000000-0000-0000-0000-000000000002",
)
def test_update_success_serializes_response(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_update_success_serializes_response(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
created_at = datetime(2026, 1, 2, 3, 4, 5, tzinfo=UTC)
monkeypatch.setattr(
ConversationService,

View File

@ -16,6 +16,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from controllers.common.errors import (
FilenameNotExistsError,
@ -282,7 +283,7 @@ class TestFileApiPost:
assert status == 201
mock_file_svc_cls.return_value.upload_file.assert_called_once()
def test_upload_no_file(self, app, mock_app_model, mock_end_user):
def test_upload_no_file(self, app: Flask, mock_app_model, mock_end_user):
"""Test NoFileUploadedError when no file in request."""
from controllers.service_api.app.file import FileApi
@ -296,7 +297,7 @@ class TestFileApiPost:
with pytest.raises(NoFileUploadedError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_too_many_files(self, app, mock_app_model, mock_end_user):
def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user):
"""Test TooManyFilesError when multiple files uploaded."""
from io import BytesIO
@ -317,7 +318,7 @@ class TestFileApiPost:
with pytest.raises(TooManyFilesError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_no_mimetype(self, app, mock_app_model, mock_end_user):
def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user):
"""Test UnsupportedFileTypeError when file has no mimetype."""
from io import BytesIO

View File

@ -11,6 +11,7 @@ from types import SimpleNamespace
from unittest.mock import ANY, MagicMock, Mock
import pytest
from flask import Flask
import services.app_generate_service as ags_module
from controllers.service_api.app.workflow_events import WorkflowEventsApi
@ -281,7 +282,7 @@ class TestHitlServiceApi:
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_workflow_events_snapshot_continue_on_pause_keeps_pause_open(
self, app, monkeypatch: pytest.MonkeyPatch
self, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
workflow_run = SimpleNamespace(
id="run-1",

View File

@ -19,6 +19,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
from controllers.service_api.app.error import NotChatAppError
@ -390,7 +391,7 @@ class TestMessageListApi:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_conversation_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_conversation_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
@ -409,7 +410,7 @@ class TestMessageListApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user)
def test_first_message_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_first_message_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"pagination_by_first_id",
@ -430,7 +431,7 @@ class TestMessageListApi:
class TestMessageFeedbackApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"create_feedback",
@ -452,7 +453,7 @@ class TestMessageFeedbackApi:
class TestAppGetFeedbacksApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
api = AppGetFeedbacksApi()
@ -476,7 +477,7 @@ class TestMessageSuggestedApi:
with pytest.raises(NotChatAppError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -492,7 +493,7 @@ class TestMessageSuggestedApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_disabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_disabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -508,7 +509,7 @@ class TestMessageSuggestedApi:
with pytest.raises(BadRequest):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_internal_error(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_internal_error(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",
@ -524,7 +525,7 @@ class TestMessageSuggestedApi:
with pytest.raises(InternalServerError):
handler(api, app_model=app_model, end_user=end_user, message_id="m1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
MessageService,
"get_suggested_questions_after_answer",

View File

@ -20,6 +20,7 @@ from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.service_api.app.error import NotWorkflowAppError
@ -366,7 +367,7 @@ class TestWorkflowRunRepository:
class TestWorkflowRunDetailApi:
def test_not_workflow_app(self, app) -> None:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -397,7 +398,7 @@ class TestWorkflowRunDetailApi:
class TestWorkflowRunApi:
def test_not_workflow_app(self, app) -> None:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -407,7 +408,7 @@ class TestWorkflowRunApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user)
def test_rate_limit(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_rate_limit(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -425,7 +426,7 @@ class TestWorkflowRunApi:
class TestWorkflowRunByIdApi:
def test_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -441,7 +442,7 @@ class TestWorkflowRunByIdApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, workflow_id="w1")
def test_draft_workflow(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_draft_workflow(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AppGenerateService,
"generate",
@ -459,7 +460,7 @@ class TestWorkflowRunByIdApi:
class TestWorkflowTaskStopApi:
def test_wrong_mode(self, app) -> None:
def test_wrong_mode(self, app: Flask) -> None:
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
@ -469,7 +470,7 @@ class TestWorkflowTaskStopApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="t1")
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
stop_mock = Mock()
send_mock = Mock()
monkeypatch.setattr(AppQueueManager, "set_stop_flag_no_user_check", stop_mock)
@ -489,7 +490,7 @@ class TestWorkflowTaskStopApi:
class TestWorkflowAppLogApi:
def test_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
class _BeginStub:
def __enter__(self):
return SimpleNamespace()
@ -557,7 +558,7 @@ class TestWorkflowRunDetailApiGet:
self,
mock_db,
mock_repo_factory,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow run detail retrieval."""
@ -579,7 +580,7 @@ class TestWorkflowRunDetailApiGet:
assert result["status"] == "succeeded"
@patch("controllers.service_api.app.workflow.db")
def test_get_workflow_run_wrong_app_mode(self, mock_db, app):
def test_get_workflow_run_wrong_app_mode(self, mock_db, app: Flask):
"""Test NotWorkflowAppError when app mode is not workflow or advanced_chat."""
from controllers.service_api.app.workflow import WorkflowRunDetailApi
@ -604,7 +605,7 @@ class TestWorkflowTaskStopApiPost:
self,
mock_queue_mgr,
mock_graph_mgr,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow task stop."""
@ -624,7 +625,7 @@ class TestWorkflowTaskStopApiPost:
mock_graph_mgr.assert_called_once()
mock_graph_mgr.return_value.send_stop_command.assert_called_once_with("task-1")
def test_stop_workflow_task_wrong_app_mode(self, app):
def test_stop_workflow_task_wrong_app_mode(self, app: Flask):
"""Test NotWorkflowAppError when app mode is not workflow."""
from controllers.service_api.app.workflow import WorkflowTaskStopApi
@ -649,7 +650,7 @@ class TestWorkflowAppLogApiGet:
self,
mock_db,
mock_wf_svc_cls,
app,
app: Flask,
mock_workflow_app,
):
"""Test successful workflow log retrieval."""

View File

@ -9,6 +9,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.app.error import NotWorkflowAppError
@ -41,7 +42,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotWorkflowAppError):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_not_found(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
@ -52,7 +53,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_workflow_run_permission_denied(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_run_permission_denied(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -70,7 +71,7 @@ class TestWorkflowEventsApi:
with pytest.raises(NotFound):
handler(api, app_model=app_model, end_user=end_user, task_id="run-1")
def test_finished_run_returns_sse(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_finished_run_returns_sse(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -103,7 +104,7 @@ class TestWorkflowEventsApi:
assert payload["task_id"] == "run-1"
assert payload["event"] == "workflow_finished"
def test_running_run_streams_events(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_running_run_streams_events(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",
@ -135,7 +136,7 @@ class TestWorkflowEventsApi:
)
workflow_generator.convert_to_event_stream.assert_called_once_with(["raw-event"])
def test_running_run_with_snapshot(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_running_run_with_snapshot(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow_run = SimpleNamespace(
id="run-1",
app_id="app-1",

View File

@ -23,6 +23,7 @@ from datetime import UTC, datetime
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import Forbidden, NotFound
@ -373,7 +374,7 @@ class TestDatasourcePluginsApiGet:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_success(self, mock_svc_cls, mock_db, app):
def test_get_plugins_success(self, mock_svc_cls, mock_db, app: Flask):
"""Test successful retrieval of datasource plugins."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
@ -396,7 +397,7 @@ class TestDatasourcePluginsApiGet:
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_get_plugins_not_found(self, mock_db, app):
def test_get_plugins_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -407,7 +408,7 @@ class TestDatasourcePluginsApiGet:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app):
def test_get_plugins_empty_list(self, mock_svc_cls, mock_db, app: Flask):
"""Test empty plugin list."""
mock_db.session.scalar.return_value = Mock()
mock_svc_instance = Mock()
@ -439,7 +440,7 @@ class TestDatasourceNodeRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.RagPipelineService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app):
def test_post_success(self, mock_ns, mock_db, mock_svc_cls, mock_current_user, mock_gen, mock_helper, app: Flask):
"""Test successful datasource node run."""
tenant_id = str(uuid.uuid4())
dataset_id = str(uuid.uuid4())
@ -473,7 +474,7 @@ class TestDatasourceNodeRunApiPost:
mock_svc_instance.run_datasource_workflow_node.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
def test_post_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -488,7 +489,7 @@ class TestDatasourceNodeRunApiPost:
)
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app):
def test_post_fails_when_current_user_not_account(self, mock_ns, mock_db, app: Flask):
"""Test AssertionError when current_user is not an Account instance."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
@ -549,7 +550,7 @@ class TestPipelineRunApiPost:
mock_gen_svc.generate.assert_called_once()
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_post_not_found(self, mock_db, app):
def test_post_not_found(self, mock_db, app: Flask):
"""Test NotFound when dataset check fails."""
mock_db.session.scalar.return_value = None
@ -561,7 +562,7 @@ class TestPipelineRunApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user", new="not_account")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.service_api_ns")
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app):
def test_post_forbidden_non_account_user(self, mock_ns, mock_db, app: Flask):
"""Test Forbidden when current_user is not an Account."""
mock_db.session.scalar.return_value = Mock()
mock_ns.payload = {
@ -585,7 +586,7 @@ class TestFileUploadApiPost:
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.FileService")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.current_user")
@patch("controllers.service_api.dataset.rag_pipeline.rag_pipeline_workflow.db")
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app):
def test_upload_success(self, mock_db, mock_current_user, mock_file_svc_cls, app: Flask):
"""Test successful file upload."""
mock_current_user.__bool__ = Mock(return_value=True)
@ -621,7 +622,7 @@ class TestFileUploadApiPost:
assert response["name"] == "doc.pdf"
assert response["extension"] == "pdf"
def test_upload_no_file(self, app):
def test_upload_no_file(self, app: Flask):
"""Test error when no file is uploaded."""
with app.test_request_context(
"/datasets/pipeline/file-upload",

View File

@ -18,6 +18,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.segment import (
@ -782,7 +783,7 @@ class TestSegmentApiGet:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -893,7 +894,7 @@ class TestSegmentApiPost:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -946,7 +947,7 @@ class TestSegmentApiPost:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -989,7 +990,7 @@ class TestSegmentApiPost:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1041,7 +1042,7 @@ class TestDatasetSegmentApiDelete:
mock_doc_svc,
mock_dataset_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1086,7 +1087,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1128,7 +1129,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_doc_svc,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1162,7 +1163,7 @@ class TestDatasetSegmentApiDelete:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1232,7 +1233,7 @@ class TestDatasetSegmentApiUpdate:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1282,7 +1283,7 @@ class TestDatasetSegmentApiUpdate:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1322,7 +1323,7 @@ class TestDatasetSegmentApiUpdate:
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1374,7 +1375,7 @@ class TestDatasetSegmentApiGetSingle:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1421,7 +1422,7 @@ class TestDatasetSegmentApiGetSingle:
mock_seg_svc,
mock_marshal,
mock_summary_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
mock_segment,
@ -1460,7 +1461,7 @@ class TestDatasetSegmentApiGetSingle:
self,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1491,7 +1492,7 @@ class TestDatasetSegmentApiGetSingle:
mock_account_fn,
mock_dataset_svc,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1526,7 +1527,7 @@ class TestDatasetSegmentApiGetSingle:
mock_dataset_svc,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1570,7 +1571,7 @@ class TestChildChunkApiGet:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1609,7 +1610,7 @@ class TestChildChunkApiGet:
self,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1638,7 +1639,7 @@ class TestChildChunkApiGet:
mock_db,
mock_account_fn,
mock_doc_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1670,7 +1671,7 @@ class TestChildChunkApiGet:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1729,7 +1730,7 @@ class TestChildChunkApiPost:
mock_doc_svc,
mock_seg_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1771,7 +1772,7 @@ class TestChildChunkApiPost:
mock_feature_svc,
mock_db,
mock_account_fn,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1809,7 +1810,7 @@ class TestChildChunkApiPost:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1863,7 +1864,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1913,7 +1914,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1954,7 +1955,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -1994,7 +1995,7 @@ class TestDatasetChildChunkApiDelete:
mock_account_fn,
mock_doc_svc,
mock_seg_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):

View File

@ -19,6 +19,7 @@ import uuid
from unittest.mock import Mock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.service_api.dataset.metadata import (
@ -76,7 +77,7 @@ class TestDatasetMetadataCreatePost:
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -106,7 +107,7 @@ class TestDatasetMetadataCreatePost:
def test_create_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -136,7 +137,7 @@ class TestDatasetMetadataCreateGet:
self,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -160,7 +161,7 @@ class TestDatasetMetadataCreateGet:
def test_get_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -201,7 +202,7 @@ class TestDatasetMetadataServiceApiPatch:
mock_dataset_svc,
mock_meta_svc,
mock_marshal,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -232,7 +233,7 @@ class TestDatasetMetadataServiceApiPatch:
def test_update_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -273,7 +274,7 @@ class TestDatasetMetadataServiceApiDelete:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -302,7 +303,7 @@ class TestDatasetMetadataServiceApiDelete:
def test_delete_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -336,7 +337,7 @@ class TestDatasetMetadataBuiltInFieldGet:
def test_get_built_in_fields_success(
self,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -382,7 +383,7 @@ class TestDatasetMetadataBuiltInFieldAction:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -414,7 +415,7 @@ class TestDatasetMetadataBuiltInFieldAction:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -441,7 +442,7 @@ class TestDatasetMetadataBuiltInFieldAction:
def test_action_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -485,7 +486,7 @@ class TestDocumentMetadataEditPost:
mock_current_user,
mock_dataset_svc,
mock_meta_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):
@ -513,7 +514,7 @@ class TestDocumentMetadataEditPost:
def test_update_documents_metadata_dataset_not_found(
self,
mock_dataset_svc,
app,
app: Flask,
mock_tenant,
mock_dataset,
):

View File

@ -5,6 +5,7 @@ Unit tests for Service API Index endpoint
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.service_api.index import IndexApi
@ -13,7 +14,7 @@ class TestIndexApi:
"""Test suite for IndexApi resource."""
@patch("controllers.service_api.index.dify_config", autospec=True)
def test_get_returns_api_info(self, mock_config, app):
def test_get_returns_api_info(self, mock_config, app: Flask):
"""Test that GET returns API metadata with correct structure."""
# Arrange
mock_config.project.version = "1.0.0-test"
@ -32,7 +33,7 @@ class TestIndexApi:
assert response["api_version"] == "v1"
assert response["server_version"] == "1.0.0-test"
def test_get_response_has_required_fields(self, app):
def test_get_response_has_required_fields(self, app: Flask):
"""Test that response contains all required fields."""
# Arrange
mock_config = MagicMock()

View File

@ -39,7 +39,7 @@ class TestValidateAndGetApiToken:
app.config["TESTING"] = True
return app
def test_missing_authorization_header(self, app):
def test_missing_authorization_header(self, app: Flask):
"""Test that Unauthorized is raised when Authorization header is missing."""
# Arrange
with app.test_request_context("/", method="GET"):
@ -50,7 +50,7 @@ class TestValidateAndGetApiToken:
validate_and_get_api_token("app")
assert "Authorization header must be provided" in str(exc_info.value)
def test_invalid_auth_scheme(self, app):
def test_invalid_auth_scheme(self, app: Flask):
"""Test that Unauthorized is raised when auth scheme is not Bearer."""
# Arrange
with app.test_request_context("/", method="GET", headers={"Authorization": "Basic token123"}):
@ -62,7 +62,7 @@ class TestValidateAndGetApiToken:
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
def test_valid_token_returns_api_token(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask):
"""Test that valid token returns the ApiToken object."""
# Arrange
mock_api_token = Mock(spec=ApiToken)
@ -84,7 +84,7 @@ class TestValidateAndGetApiToken:
@patch("controllers.service_api.wraps.record_token_usage")
@patch("controllers.service_api.wraps.ApiTokenCache")
@patch("controllers.service_api.wraps.fetch_token_with_single_flight")
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app):
def test_invalid_token_raises_unauthorized(self, mock_fetch_token, mock_cache_cls, mock_record_usage, app: Flask):
"""Test that invalid token raises Unauthorized."""
# Arrange
from werkzeug.exceptions import Unauthorized
@ -161,7 +161,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_not_found_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app no longer exists."""
# Arrange
mock_api_token = Mock()
@ -182,7 +182,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_status_abnormal_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app status is abnormal."""
# Arrange
mock_api_token = Mock()
@ -205,7 +205,7 @@ class TestValidateAppToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app):
def test_app_api_disabled_raises_forbidden(self, mock_validate_token, mock_db, app: Flask):
"""Test that Forbidden is raised when app API is disabled."""
# Arrange
mock_api_token = Mock()
@ -240,7 +240,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app):
def test_allows_when_under_limit(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that request is allowed when under resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -264,7 +264,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app):
def test_rejects_when_at_limit(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that Forbidden is raised when at resource limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -287,7 +287,7 @@ class TestCloudEditionBillingResourceCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app):
def test_allows_when_billing_disabled(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that request is allowed when billing is disabled."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -320,7 +320,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app):
def test_rejects_add_segment_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that add_segment is rejected in SANDBOX plan."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -342,7 +342,7 @@ class TestCloudEditionBillingKnowledgeLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_features")
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app):
def test_allows_other_operations_in_sandbox(self, mock_get_features, mock_validate_token, app: Flask):
"""Test that non-add_segment operations are allowed in SANDBOX."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -376,7 +376,7 @@ class TestCloudEditionBillingRateLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app):
def test_allows_within_rate_limit(self, mock_get_rate_limit, mock_validate_token, app: Flask):
"""Test that request is allowed when within rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -406,7 +406,7 @@ class TestCloudEditionBillingRateLimitCheck:
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.FeatureService.get_knowledge_rate_limit")
@patch("controllers.service_api.wraps.db")
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app):
def test_rejects_over_rate_limit(self, mock_db, mock_get_rate_limit, mock_validate_token, app: Flask):
"""Test that Forbidden is raised when over rate limit."""
# Arrange
mock_validate_token.return_value = Mock(tenant_id="tenant123")
@ -445,7 +445,7 @@ class TestValidateDatasetToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
@patch("controllers.service_api.wraps.current_app")
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app):
def test_valid_dataset_token(self, mock_current_app, mock_validate_token, mock_db, mock_user_logged_in, app: Flask):
"""Test that valid dataset token allows access."""
# Arrange
# Use standard Mock for login_manager
@ -487,7 +487,7 @@ class TestValidateDatasetToken:
@patch("controllers.service_api.wraps.db")
@patch("controllers.service_api.wraps.validate_and_get_api_token")
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app):
def test_dataset_not_found_raises_not_found(self, mock_validate_token, mock_db, app: Flask):
"""Test that NotFound is raised when dataset doesn't exist."""
# Arrange
mock_api_token = Mock()

View File

@ -13,7 +13,7 @@ class TestExternalDataFetch:
app = Flask(__name__)
return app
def test_fetch_success(self, app):
def test_fetch_success(self, app: Flask):
with app.app_context():
fetcher = ExternalDataFetch()
@ -79,7 +79,7 @@ class TestExternalDataFetch:
assert result_inputs == inputs
assert result_inputs is not inputs # Should be a copy
def test_fetch_with_none_variable(self, app):
def test_fetch_with_none_variable(self, app: Flask):
with app.app_context():
fetcher = ExternalDataFetch()
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={})
@ -95,7 +95,7 @@ class TestExternalDataFetch:
assert "var1" not in result_inputs
assert result_inputs == {"in": "val"}
def test_query_external_data_tool(self, app):
def test_query_external_data_tool(self, app: Flask):
fetcher = ExternalDataFetch()
tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"})

View File

@ -1,79 +0,0 @@
from unittest.mock import MagicMock, patch
from models.account import TenantPluginPermission
MODULE = "services.plugin.plugin_permission_service"
def _patched_session():
"""Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock()
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
session.begin.return_value.__enter__ = MagicMock(return_value=session)
session.begin.return_value.__exit__ = MagicMock(return_value=False)
mock_factory = MagicMock()
mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session
class TestGetPermission:
def test_returns_permission_when_found(self):
p1, session = _patched_session()
permission = MagicMock()
session.scalar.return_value = permission
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
assert result is permission
def test_returns_none_when_not_found(self):
p1, session = _patched_session()
session.scalar.return_value = None
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.get_permission("t1")
assert result is None
class TestChangePermission:
def test_creates_new_permission_when_not_exists(self):
p1, session = _patched_session()
session.scalar.return_value = None
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginPermission") as perm_cls:
perm_cls.return_value = MagicMock()
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.EVERYONE, TenantPluginPermission.DebugPermission.EVERYONE
)
assert result is True
session.begin.assert_called_once()
session.add.assert_called_once()
def test_updates_existing_permission(self):
p1, session = _patched_session()
existing = MagicMock()
session.scalar.return_value = existing
with p1:
from services.plugin.plugin_permission_service import PluginPermissionService
result = PluginPermissionService.change_permission(
"t1", TenantPluginPermission.InstallPermission.ADMINS, TenantPluginPermission.DebugPermission.ADMINS
)
assert result is True
session.begin.assert_called_once()
assert existing.install_permission == TenantPluginPermission.InstallPermission.ADMINS
assert existing.debug_permission == TenantPluginPermission.DebugPermission.ADMINS
session.add.assert_not_called()

51
docker/.env.default Normal file
View File

@ -0,0 +1,51 @@
# ------------------------------------------------------------------
# Minimal defaults for Docker Compose deployments.
#
# Keep local changes in .env. Use .env.example as the full reference
# for advanced and service-specific settings.
# ------------------------------------------------------------------
# Public URLs used when Dify generates links. Change these together when
# exposing Dify under another hostname, IP address, or port.
CONSOLE_WEB_URL=http://localhost
SERVICE_API_URL=http://localhost
APP_WEB_URL=http://localhost
FILES_URL=http://localhost
INTERNAL_FILES_URL=http://api:5001
TRIGGER_URL=http://localhost
ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id}
NEXT_PUBLIC_SOCKET_URL=ws://localhost
EXPOSE_PLUGIN_DEBUGGING_HOST=localhost
EXPOSE_PLUGIN_DEBUGGING_PORT=5003
# Built-in metadata database defaults.
DB_TYPE=postgresql
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
DB_HOST=db_postgres
DB_PORT=5432
DB_DATABASE=dify
# Built-in Redis defaults.
REDIS_HOST=redis
REDIS_PORT=6379
REDIS_PASSWORD=difyai123456
# Default file storage.
STORAGE_TYPE=opendal
OPENDAL_SCHEME=fs
OPENDAL_FS_ROOT=storage
# Default vector database.
VECTOR_STORE=weaviate
# Internal service authentication. Paired values must match.
PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi
PLUGIN_DIFY_INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Host ports.
EXPOSE_NGINX_PORT=80
EXPOSE_NGINX_SSL_PORT=443
# Docker Compose profiles for bundled services.
COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql}

View File

@ -1003,7 +1003,7 @@ NOTION_INTERNAL_SECRET=
# ------------------------------
# Mail type, support: resend, smtp, sendgrid
MAIL_TYPE=resend
MAIL_TYPE=
# Default send from email address, if not specified
# If using SendGrid, use the 'from' field for authentication if necessary.
@ -1011,7 +1011,7 @@ MAIL_DEFAULT_SEND_FROM=
# API-Key for the Resend email provider, used when MAIL_TYPE is `resend`.
RESEND_API_URL=https://api.resend.com
RESEND_API_KEY=your-resend-api-key
RESEND_API_KEY=
# SMTP server configuration, used when MAIL_TYPE is `smtp`
@ -1359,10 +1359,10 @@ NGINX_ENABLE_CERTBOT_CHALLENGE=false
# ------------------------------
# Email address (required to get certificates from Let's Encrypt)
CERTBOT_EMAIL=your_email@example.com
CERTBOT_EMAIL=
# Domain name
CERTBOT_DOMAIN=your_domain.com
CERTBOT_DOMAIN=
# certbot command options
# i.e: --force-renewal --dry-run --test-cert --debug

View File

@ -7,28 +7,28 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- **Certbot Container**: `docker-compose.yaml` now contains `certbot` for managing SSL certificates. This container automatically renews certificates and ensures secure HTTPS connections.\
For more information, refer `docker/certbot/README.md`.
- **Persistent Environment Variables**: Environment variables are now managed through a `.env` file, ensuring that your configurations persist across deployments.
- **Persistent Environment Variables**: Default environment variables are managed through `.env.default`, while local overrides are stored in `.env`, ensuring that your configurations persist across deployments.
> What is `.env`? </br> </br>
> The `.env` file is a crucial component in Docker and Docker Compose environments, serving as a centralized configuration file where you can define environment variables that are accessible to the containers at runtime. This file simplifies the management of environment settings across different stages of development, testing, and production, providing consistency and ease of configuration to deployments.
> The `.env` file is a local override file. Keep it small by adding only the values that differ from `.env.default`. Use `.env.example` as the full reference when you need advanced configuration.
- **Unified Vector Database Services**: All vector database services are now managed from a single Docker Compose file `docker-compose.yaml`. You can switch between different vector databases by setting the `VECTOR_STORE` environment variable in your `.env` file.
- **Mandatory .env File**: A `.env` file is now required to run `docker compose up`. This file is crucial for configuring your deployment and for any custom settings to persist through upgrades.
- **Local .env Overrides**: The `dify-compose` and `dify-compose.ps1` wrappers create `.env` if it is missing and generate a persistent `SECRET_KEY` for this deployment.
### How to Deploy Dify with `docker-compose.yaml`
1. **Prerequisites**: Ensure Docker and Docker Compose are installed on your system.
1. **Environment Setup**:
- Navigate to the `docker` directory.
- Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`.
- Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options.
- **Optional (Recommended for upgrades)**:
You may use the environment synchronization tool to help keep your `.env` file aligned with the latest `.env.example` updates, while preserving your custom settings.
This is especially useful when upgrading Dify or managing a large, customized `.env` file.
- No copy step is required. The `dify-compose` wrappers create `.env` if it is missing and write a generated `SECRET_KEY` to it.
- When prompted on first run, press Enter to use the default deployment, or answer `y` to stop and edit `.env` first.
- Customize `.env` only when you need to override defaults from `.env.default`. Refer to `.env.example` for the full list of available variables.
- **Optional (for advanced deployments)**:
If you maintain a full `.env` file copied from `.env.example`, you may use the environment synchronization tool to keep it aligned with the latest `.env.example` updates while preserving your custom settings.
See the [Environment Variables Synchronization](#environment-variables-synchronization) section below.
1. **Running the Services**:
- Execute `docker compose up` from the `docker` directory to start the services.
- Execute `./dify-compose up -d` from the `docker` directory to start the services. On Windows PowerShell, run `.\dify-compose.ps1 up -d`.
- To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`.
1. **SSL Certificate Setup**:
- Refer `docker/certbot/README.md` to set up SSL certificates using Certbot.
@ -58,7 +58,13 @@ For users migrating from the `docker-legacy` setup:
1. **Data Migration**:
- Ensure that data from services like databases and caches is backed up and migrated appropriately to the new structure if necessary.
### Overview of `.env`
### Overview of `.env.default`, `.env`, and `.env.example`
- `.env.default` contains the minimal default configuration for Docker Compose deployments.
- `.env` contains the generated `SECRET_KEY` plus any local overrides.
- `.env.example` is the full reference for advanced configuration.
The `dify-compose` wrappers merge `.env.default` and `.env` into a temporary environment file, append paired internal service keys when needed, and remove the temporary file after Docker Compose starts.
#### Key Modules and Customization
@ -118,9 +124,11 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w
### Environment Variables Synchronization
When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.example`.
When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.default` or `.env.example`.
To help keep your existing `.env` file up to date **without losing your custom values**, an optional environment variables synchronization tool is provided.
If you use the default override-only workflow, review `.env.default` and add only the values you need to override to `.env`.
If you maintain a full `.env` file copied from `.env.example`, an optional environment variables synchronization tool is provided.
> This tool performs a **one-way synchronization** from `.env.example` to `.env`.
> Existing values in `.env` are never overwritten automatically.
@ -143,9 +151,9 @@ Before synchronization, the current `.env` file is saved to the `env-backup/` di
**When to use**
- After upgrading Dify to a newer version
- After upgrading Dify to a newer version with a full `.env` file
- When `.env.example` has been updated with new environment variables
- When managing a large or heavily customized `.env` file
- When managing a large or heavily customized `.env` file copied from `.env.example`
**Usage**

334
docker/dify-compose Executable file
View File

@ -0,0 +1,334 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
DEFAULT_ENV_FILE=".env.default"
USER_ENV_FILE=".env"
log() {
printf '%s\n' "$*" >&2
}
die() {
printf 'Error: %s\n' "$*" >&2
exit 1
}
detect_compose() {
if docker compose version >/dev/null 2>&1; then
COMPOSE_CMD=(docker compose)
return
fi
if command -v docker-compose >/dev/null 2>&1; then
COMPOSE_CMD=(docker-compose)
return
fi
die "Docker Compose is not available. Install Docker Compose, then run this command again."
}
generate_secret_key() {
if command -v openssl >/dev/null 2>&1; then
openssl rand -base64 42
return
fi
if command -v dd >/dev/null 2>&1 && command -v base64 >/dev/null 2>&1; then
dd if=/dev/urandom bs=42 count=1 2>/dev/null | base64 | tr -d '\n'
printf '\n'
return
fi
return 1
}
ensure_env_files() {
[[ -f "$DEFAULT_ENV_FILE" ]] || die "$DEFAULT_ENV_FILE is missing."
if [[ -f "$USER_ENV_FILE" ]]; then
return
fi
: >"$USER_ENV_FILE"
if [[ ! -t 0 ]]; then
log "Created $USER_ENV_FILE for local overrides."
return
fi
printf 'Created %s for local overrides.\n' "$USER_ENV_FILE"
printf 'Do you need a custom deployment now? (Most users can press Enter to skip.) [y/N] '
read -r answer
case "${answer:-}" in
y | Y | yes | YES | Yes)
cat <<'EOF'
Edit .env with the settings you want to override, using .env.example as the full reference.
Run ./dify-compose up -d again when you are ready.
EOF
exit 0
;;
esac
}
user_env_value() {
local key="$1"
awk -F= -v target="$key" '
/^[[:space:]]*#/ || !/=/{ next }
{
key = $1
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
if (key == target) {
value = substr($0, index($0, "=") + 1)
gsub(/^[[:space:]]+|[[:space:]]+$/, "", value)
if ((value ~ /^".*"$/) || (value ~ /^'\''.*'\''$/)) {
value = substr(value, 2, length(value) - 2)
}
result = value
}
}
END { print result }
' "$USER_ENV_FILE"
}
set_user_env_value() {
local key="$1"
local value="$2"
local temp_file
temp_file="$(mktemp "${TMPDIR:-/tmp}/dify-env.XXXXXX")"
awk -F= -v target="$key" -v replacement="$key=$value" '
BEGIN { replaced = 0 }
/^[[:space:]]*#/ || !/=/{ print; next }
{
key = $1
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
if (key == target) {
if (!replaced) {
print replacement
replaced = 1
}
next
}
print
}
END {
if (!replaced) {
print replacement
}
}
' "$USER_ENV_FILE" >"$temp_file"
mv "$temp_file" "$USER_ENV_FILE"
}
ensure_secret_key() {
local current_secret_key
local secret_key
current_secret_key="$(user_env_value SECRET_KEY)"
if [[ -n "$current_secret_key" ]]; then
return
fi
secret_key="$(generate_secret_key)" || die "Unable to generate SECRET_KEY. Install openssl or configure SECRET_KEY in .env."
set_user_env_value SECRET_KEY "$secret_key"
log "Generated SECRET_KEY in $USER_ENV_FILE."
}
env_value() {
local key="$1"
awk -F= -v target="$key" '
/^[[:space:]]*#/ || !/=/{ next }
{
key = $1
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
if (key == target) {
value = substr($0, index($0, "=") + 1)
gsub(/^[[:space:]]+|[[:space:]]+$/, "", value)
if ((value ~ /^".*"$/) || (value ~ /^'\''.*'\''$/)) {
value = substr(value, 2, length(value) - 2)
}
result = value
}
}
END { print result }
' "$DEFAULT_ENV_FILE" "$USER_ENV_FILE"
}
user_overrides() {
local key="$1"
grep -Eq "^[[:space:]]*${key}[[:space:]]*=" "$USER_ENV_FILE"
}
write_merged_env() {
awk '
function trim(s) {
sub(/^[[:space:]]+/, "", s)
sub(/[[:space:]]+$/, "", s)
return s
}
/^[[:space:]]*#/ || !/=/{ next }
{
key = $0
sub(/=.*/, "", key)
key = trim(key)
if (key == "") {
next
}
value = substr($0, index($0, "=") + 1)
value = trim(value)
if (!(key in seen)) {
order[++count] = key
seen[key] = 1
}
values[key] = value
}
END {
for (i = 1; i <= count; i++) {
key = order[i]
print key "=" values[key]
}
}
' "$DEFAULT_ENV_FILE" "$USER_ENV_FILE" >"$MERGED_ENV_FILE"
}
set_merged_env_value() {
local key="$1"
local value="$2"
local temp_file
temp_file="$(mktemp "${TMPDIR:-/tmp}/dify-compose-env.XXXXXX")"
awk -F= -v target="$key" -v replacement="$key=$value" '
BEGIN { replaced = 0 }
/^[[:space:]]*#/ || !/=/{ print; next }
{
key = $1
gsub(/^[[:space:]]+|[[:space:]]+$/, "", key)
if (key == target) {
if (!replaced) {
print replacement
replaced = 1
}
next
}
print
}
END {
if (!replaced) {
print replacement
}
}
' "$MERGED_ENV_FILE" >"$temp_file"
mv "$temp_file" "$MERGED_ENV_FILE"
}
set_if_not_overridden() {
local key="$1"
local value="$2"
if user_overrides "$key"; then
return
fi
set_merged_env_value "$key" "$value"
}
metadata_db_host() {
case "$1" in
mysql) printf 'db_mysql' ;;
postgresql | '') printf 'db_postgres' ;;
*) printf '%s' "$(env_value DB_HOST)" ;;
esac
}
metadata_db_port() {
case "$1" in
mysql) printf '3306' ;;
postgresql | '') printf '5432' ;;
*) printf '%s' "$(env_value DB_PORT)" ;;
esac
}
metadata_db_user() {
case "$1" in
mysql) printf 'root' ;;
postgresql | '') printf 'postgres' ;;
*) printf '%s' "$(env_value DB_USERNAME)" ;;
esac
}
build_merged_env() {
MERGED_ENV_FILE="$(mktemp "${TMPDIR:-/tmp}/dify-compose.XXXXXX")"
trap 'rm -f "$MERGED_ENV_FILE"' EXIT
write_merged_env
local db_type
local redis_host
local redis_port
local redis_username
local redis_password
local redis_auth
local code_execution_api_key
local weaviate_api_key
db_type="$(env_value DB_TYPE)"
set_if_not_overridden DB_HOST "$(metadata_db_host "$db_type")"
set_if_not_overridden DB_PORT "$(metadata_db_port "$db_type")"
set_if_not_overridden DB_USERNAME "$(metadata_db_user "$db_type")"
if ! user_overrides CELERY_BROKER_URL; then
redis_host="$(env_value REDIS_HOST)"
redis_port="$(env_value REDIS_PORT)"
redis_username="$(env_value REDIS_USERNAME)"
redis_password="$(env_value REDIS_PASSWORD)"
redis_auth=""
if [[ -n "$redis_username" && -n "$redis_password" ]]; then
redis_auth="${redis_username}:${redis_password}@"
elif [[ -n "$redis_password" ]]; then
redis_auth=":${redis_password}@"
elif [[ -n "$redis_username" ]]; then
redis_auth="${redis_username}@"
fi
set_merged_env_value CELERY_BROKER_URL "redis://${redis_auth}${redis_host:-redis}:${redis_port:-6379}/1"
fi
if ! user_overrides SANDBOX_API_KEY; then
code_execution_api_key="$(env_value CODE_EXECUTION_API_KEY)"
set_if_not_overridden SANDBOX_API_KEY "${code_execution_api_key:-dify-sandbox}"
fi
if ! user_overrides WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS; then
weaviate_api_key="$(env_value WEAVIATE_API_KEY)"
set_if_not_overridden WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS \
"${weaviate_api_key:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}"
fi
}
main() {
detect_compose
ensure_env_files
ensure_secret_key
build_merged_env
if [[ "$#" -eq 0 ]]; then
set -- up -d
fi
"${COMPOSE_CMD[@]}" --env-file "$MERGED_ENV_FILE" "$@"
}
main "$@"

317
docker/dify-compose.ps1 Normal file
View File

@ -0,0 +1,317 @@
$ErrorActionPreference = "Stop"
Set-StrictMode -Version Latest
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
Set-Location $ScriptDir
$DefaultEnvFile = ".env.default"
$UserEnvFile = ".env"
$MergedEnvFile = $null
$Utf8NoBom = New-Object System.Text.UTF8Encoding -ArgumentList $false
function Write-Info {
param([string]$Message)
[Console]::Error.WriteLine($Message)
}
function Fail {
param([string]$Message)
[Console]::Error.WriteLine("Error: $Message")
exit 1
}
function Test-CommandSuccess {
param([string[]]$Command)
try {
$Executable = $Command[0]
$CommandArgs = @()
if ($Command.Length -gt 1) {
$CommandArgs = @($Command[1..($Command.Length - 1)])
}
& $Executable @CommandArgs *> $null
return $LASTEXITCODE -eq 0
}
catch {
return $false
}
}
function Get-ComposeCommand {
if (Test-CommandSuccess @("docker", "compose", "version")) {
return @("docker", "compose")
}
if ((Get-Command "docker-compose" -ErrorAction SilentlyContinue) -and (Test-CommandSuccess @("docker-compose", "version"))) {
return @("docker-compose")
}
Fail "Docker Compose is not available. Install Docker Compose, then run this command again."
}
function New-SecretKey {
$Bytes = New-Object byte[] 42
$Generator = [System.Security.Cryptography.RandomNumberGenerator]::Create()
try {
$Generator.GetBytes($Bytes)
}
finally {
$Generator.Dispose()
}
return [Convert]::ToBase64String($Bytes)
}
function Ensure-EnvFiles {
if (-not (Test-Path $DefaultEnvFile -PathType Leaf)) {
Fail "$DefaultEnvFile is missing."
}
if (Test-Path $UserEnvFile -PathType Leaf) {
return
}
New-Item -ItemType File -Path $UserEnvFile | Out-Null
if ([Console]::IsInputRedirected) {
Write-Info "Created $UserEnvFile for local overrides."
return
}
Write-Info "Created $UserEnvFile for local overrides."
$Answer = Read-Host "Do you need a custom deployment now? (Most users can press Enter to skip.) [y/N]"
if ($Answer -match "^(y|yes)$") {
Write-Output "Edit .env with the settings you want to override, using .env.example as the full reference."
Write-Output "Run .\dify-compose.ps1 up -d again when you are ready."
exit 0
}
}
function Read-EnvFile {
param([string]$Path)
$Values = [ordered]@{}
if (-not (Test-Path $Path -PathType Leaf)) {
return $Values
}
foreach ($Line in Get-Content -Path $Path) {
if ($Line -match "^\s*#" -or $Line -notmatch "=") {
continue
}
$SeparatorIndex = $Line.IndexOf("=")
$Key = $Line.Substring(0, $SeparatorIndex).Trim()
$Value = $Line.Substring($SeparatorIndex + 1).Trim()
if (($Value.StartsWith('"') -and $Value.EndsWith('"')) -or ($Value.StartsWith("'") -and $Value.EndsWith("'"))) {
$Value = $Value.Substring(1, $Value.Length - 2)
}
if ($Key.Length -gt 0) {
$Values[$Key] = $Value
}
}
return $Values
}
function Set-UserEnvValue {
param(
[string]$Key,
[string]$Value
)
$Path = [string](Resolve-Path $UserEnvFile)
$Lines = [System.IO.File]::ReadAllLines($Path, [System.Text.Encoding]::UTF8)
$Output = New-Object System.Collections.Generic.List[string]
$Replaced = $false
foreach ($Line in $Lines) {
if ($Line -match "^\s*#" -or $Line -notmatch "=") {
$Output.Add($Line)
continue
}
$SeparatorIndex = $Line.IndexOf("=")
$CurrentKey = $Line.Substring(0, $SeparatorIndex).Trim()
if ($CurrentKey -eq $Key) {
if (-not $Replaced) {
$Output.Add("$Key=$Value")
$Replaced = $true
}
continue
}
$Output.Add($Line)
}
if (-not $Replaced) {
$Output.Add("$Key=$Value")
}
[System.IO.File]::WriteAllLines($Path, $Output, $Utf8NoBom)
}
function Ensure-SecretKey {
$Values = Read-EnvFile $UserEnvFile
if ($Values.Contains("SECRET_KEY") -and $Values["SECRET_KEY"]) {
return
}
Set-UserEnvValue "SECRET_KEY" (New-SecretKey)
Write-Info "Generated SECRET_KEY in $UserEnvFile."
}
function Merge-EnvValues {
$Values = [ordered]@{}
foreach ($Entry in (Read-EnvFile $DefaultEnvFile).GetEnumerator()) {
$Values[$Entry.Key] = $Entry.Value
}
foreach ($Entry in (Read-EnvFile $UserEnvFile).GetEnumerator()) {
$Values[$Entry.Key] = $Entry.Value
}
return $Values
}
function User-Overrides {
param([string]$Key)
if (-not (Test-Path $UserEnvFile -PathType Leaf)) {
return $false
}
return [bool](Select-String -Path $UserEnvFile -Pattern "^\s*$([regex]::Escape($Key))\s*=" -Quiet)
}
function Metadata-DbHost {
param([string]$DbType, $Values)
switch ($DbType) {
"mysql" { return "db_mysql" }
"postgresql" { return "db_postgres" }
"" { return "db_postgres" }
default { return $Values["DB_HOST"] }
}
}
function Metadata-DbPort {
param([string]$DbType, $Values)
switch ($DbType) {
"mysql" { return "3306" }
"postgresql" { return "5432" }
"" { return "5432" }
default { return $Values["DB_PORT"] }
}
}
function Metadata-DbUser {
param([string]$DbType, $Values)
switch ($DbType) {
"mysql" { return "root" }
"postgresql" { return "postgres" }
"" { return "postgres" }
default { return $Values["DB_USERNAME"] }
}
}
function Write-MergedEnv {
param($Values)
$Output = New-Object System.Collections.Generic.List[string]
foreach ($Entry in $Values.GetEnumerator()) {
$Output.Add("$($Entry.Key)=$($Entry.Value)")
}
[System.IO.File]::WriteAllLines($MergedEnvFile, $Output, $Utf8NoBom)
}
function Build-MergedEnv {
$Values = Merge-EnvValues
$script:MergedEnvFile = [System.IO.Path]::GetTempFileName()
$DbType = if ($Values.Contains("DB_TYPE")) { $Values["DB_TYPE"] } else { "postgresql" }
if (-not (User-Overrides "DB_HOST")) {
$Values["DB_HOST"] = Metadata-DbHost $DbType $Values
}
if (-not (User-Overrides "DB_PORT")) {
$Values["DB_PORT"] = Metadata-DbPort $DbType $Values
}
if (-not (User-Overrides "DB_USERNAME")) {
$Values["DB_USERNAME"] = Metadata-DbUser $DbType $Values
}
if (-not (User-Overrides "CELERY_BROKER_URL")) {
$RedisHost = if ($Values.Contains("REDIS_HOST") -and $Values["REDIS_HOST"]) { $Values["REDIS_HOST"] } else { "redis" }
$RedisPort = if ($Values.Contains("REDIS_PORT") -and $Values["REDIS_PORT"]) { $Values["REDIS_PORT"] } else { "6379" }
$RedisUsername = if ($Values.Contains("REDIS_USERNAME")) { $Values["REDIS_USERNAME"] } else { "" }
$RedisPassword = if ($Values.Contains("REDIS_PASSWORD")) { $Values["REDIS_PASSWORD"] } else { "" }
$RedisAuth = ""
if ($RedisUsername -and $RedisPassword) {
$RedisAuth = "${RedisUsername}:${RedisPassword}@"
}
elseif ($RedisPassword) {
$RedisAuth = ":${RedisPassword}@"
}
elseif ($RedisUsername) {
$RedisAuth = "${RedisUsername}@"
}
$Values["CELERY_BROKER_URL"] = "redis://$RedisAuth${RedisHost}:${RedisPort}/1"
}
if (-not (User-Overrides "SANDBOX_API_KEY")) {
$CodeExecutionApiKey = if ($Values.Contains("CODE_EXECUTION_API_KEY") -and $Values["CODE_EXECUTION_API_KEY"]) { $Values["CODE_EXECUTION_API_KEY"] } else { "dify-sandbox" }
$Values["SANDBOX_API_KEY"] = $CodeExecutionApiKey
}
if (-not (User-Overrides "WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS")) {
$WeaviateApiKey = if ($Values.Contains("WEAVIATE_API_KEY") -and $Values["WEAVIATE_API_KEY"]) { $Values["WEAVIATE_API_KEY"] } else { "WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih" }
$Values["WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS"] = $WeaviateApiKey
}
Write-MergedEnv $Values
}
$ComposeCommand = Get-ComposeCommand
try {
Ensure-EnvFiles
Ensure-SecretKey
Build-MergedEnv
$ComposeArgs = @($args)
if ($ComposeArgs.Count -eq 0) {
$ComposeArgs = @("up", "-d")
}
$ComposeCommandArgs = @()
if ($ComposeCommand.Length -gt 1) {
$ComposeCommandArgs = @($ComposeCommand[1..($ComposeCommand.Length - 1)])
}
$ComposeExecutable = $ComposeCommand[0]
& $ComposeExecutable @ComposeCommandArgs --env-file $MergedEnvFile @ComposeArgs
exit $LASTEXITCODE
}
finally {
if ($MergedEnvFile -and (Test-Path $MergedEnvFile -PathType Leaf)) {
Remove-Item -Force $MergedEnvFile
}
}

View File

@ -170,8 +170,8 @@ services:
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
@ -402,8 +402,8 @@ services:
- ./certbot/update-cert.template.txt:/update-cert.template.txt
- ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh
environment:
- CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_EMAIL=${CERTBOT_EMAIL:-}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ["/docker-entrypoint.sh"]
command: ["tail", "-f", "/dev/null"]

View File

@ -441,10 +441,10 @@ x-shared-env: &shared-api-worker-env
NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-}
NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-}
NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-}
MAIL_TYPE: ${MAIL_TYPE:-resend}
MAIL_TYPE: ${MAIL_TYPE:-}
MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-}
RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com}
RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key}
RESEND_API_KEY: ${RESEND_API_KEY:-}
SMTP_SERVER: ${SMTP_SERVER:-}
SMTP_PORT: ${SMTP_PORT:-465}
SMTP_USERNAME: ${SMTP_USERNAME:-}
@ -586,8 +586,8 @@ x-shared-env: &shared-api-worker-env
NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s}
NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s}
NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false}
CERTBOT_EMAIL: ${CERTBOT_EMAIL:-your_email@example.com}
CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-your_domain.com}
CERTBOT_EMAIL: ${CERTBOT_EMAIL:-}
CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-}
CERTBOT_OPTIONS: ${CERTBOT_OPTIONS:-}
SSRF_HTTP_PORT: ${SSRF_HTTP_PORT:-3128}
SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid}
@ -894,8 +894,8 @@ services:
ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false}
MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai}
MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-}
TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10}
INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000}
LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100}
MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10}
MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10}
@ -1126,8 +1126,8 @@ services:
- ./certbot/update-cert.template.txt:/update-cert.template.txt
- ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh
environment:
- CERTBOT_EMAIL=${CERTBOT_EMAIL}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN}
- CERTBOT_EMAIL=${CERTBOT_EMAIL:-}
- CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-}
- CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-}
entrypoint: ["/docker-entrypoint.sh"]
command: ["tail", "-f", "/dev/null"]

View File

@ -315,20 +315,12 @@
"count": 4
}
},
"web/app/components/app/configuration/config-var/config-modal/type-select.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/app/configuration/config-var/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/app/configuration/config-var/select-var-type.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
@ -363,9 +355,6 @@
}
},
"web/app/components/app/configuration/config/assistant-type-picker/index.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}
@ -401,11 +390,6 @@
"count": 1
}
},
"web/app/components/app/configuration/config/automatic/version-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/app/configuration/config/code-generator/get-code-generator-res.tsx": {
"no-restricted-imports": {
"count": 1
@ -774,11 +758,6 @@
"count": 1
}
},
"web/app/components/base/chat/chat-with-history/sidebar/rename-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/base/chat/chat/answer/agent-content.tsx": {
"style/multiline-ternary": {
"count": 2
@ -800,11 +779,6 @@
"count": 1
}
},
"web/app/components/base/chat/chat/answer/operation.tsx": {
"no-restricted-imports": {
"count": 2
}
},
"web/app/components/base/chat/chat/answer/workflow-process.tsx": {
"react/set-state-in-effect": {
"count": 1
@ -1055,14 +1029,6 @@
"count": 3
}
},
"web/app/components/base/form/components/base/base-field.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 3
}
},
"web/app/components/base/form/components/base/base-form.tsx": {
"ts/no-explicit-any": {
"count": 6
@ -1589,14 +1555,6 @@
"count": 1
}
},
"web/app/components/base/modal/modal.stories.tsx": {
"no-console": {
"count": 4
},
"react/set-state-in-effect": {
"count": 1
}
},
"web/app/components/base/new-audio-button/index.tsx": {
"ts/no-explicit-any": {
"count": 1
@ -2116,11 +2074,6 @@
"count": 1
}
},
"web/app/components/datasets/documents/create-from-pipeline/data-source/base/credential-selector/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/datasets/documents/create-from-pipeline/data-source/online-documents/index.tsx": {
"ts/no-explicit-any": {
"count": 1
@ -2389,11 +2342,6 @@
"count": 1
}
},
"web/app/components/datasets/settings/index-method/index.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/develop/code.tsx": {
"ts/no-empty-object-type": {
"count": 1
@ -2579,11 +2527,8 @@
"erasable-syntax-only/enums": {
"count": 1
},
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 3
"count": 2
}
},
"web/app/components/header/account-setting/model-provider-page/declarations.ts": {
@ -2612,11 +2557,6 @@
"count": 1
}
},
"web/app/components/header/account-setting/model-provider-page/model-auth/credential-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts": {
"no-barrel-files/no-barrel-files": {
"count": 6
@ -2912,44 +2852,11 @@
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/hooks/use-oauth-client-state.ts": {
"erasable-syntax-only/enums": {
"count": 2
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": {
"no-barrel-files/no-barrel-files": {
"count": 3
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/oauth-client.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/create/types.ts": {
"erasable-syntax-only/enums": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/edit/apikey-edit-modal.tsx": {
"erasable-syntax-only/enums": {
"count": 1
},
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/edit/manual-edit-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/edit/oauth-edit-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": {
"no-barrel-files/no-barrel-files": {
"count": 2
@ -2978,11 +2885,6 @@
"count": 7
}
},
"web/app/components/plugins/plugin-detail-panel/tool-selector/components/schema-modal.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-detail-panel/tool-selector/components/tool-item.tsx": {
"no-restricted-imports": {
"count": 1
@ -2998,11 +2900,6 @@
"count": 5
}
},
"web/app/components/plugins/plugin-item/action.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/plugins/plugin-item/index.tsx": {
"no-restricted-imports": {
"count": 1
@ -3681,11 +3578,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/error-handle/error-handle-type-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/error-handle/types.ts": {
"erasable-syntax-only/enums": {
"count": 1
@ -3782,11 +3674,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/variable/var-type-picker.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts": {
"react/no-unnecessary-use-prefix": {
"count": 2
@ -4106,16 +3993,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/if-else/components/condition-list/condition-operator.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/if-else/components/condition-number-input.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/if-else/default.ts": {
"ts/no-explicit-any": {
"count": 1
@ -4151,11 +4028,6 @@
"count": 6
}
},
"web/app/components/workflow/nodes/knowledge-base/components/chunk-structure/selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/hooks.tsx": {
"ts/no-explicit-any": {
"count": 4
@ -4199,11 +4071,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/metadata-filter-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/knowledge-retrieval/default.ts": {
"ts/no-explicit-any": {
"count": 1
@ -4285,11 +4152,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/edit-card/type-selector.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/hooks.ts": {
"ts/no-explicit-any": {
"count": 1
@ -4341,16 +4203,6 @@
"count": 1
}
},
"web/app/components/workflow/nodes/loop/components/condition-list/condition-operator.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/loop/components/condition-number-input.tsx": {
"no-restricted-imports": {
"count": 1
}
},
"web/app/components/workflow/nodes/loop/components/loop-variables/form-item.tsx": {
"ts/no-explicit-any": {
"count": 3
@ -4522,9 +4374,6 @@
}
},
"web/app/components/workflow/nodes/tool/components/tool-form/item.tsx": {
"no-restricted-imports": {
"count": 1
},
"ts/no-explicit-any": {
"count": 1
}

View File

@ -4,6 +4,17 @@ import antfu, { GLOB_MARKDOWN } from '@antfu/eslint-config'
import md from 'eslint-markdown'
import markdownPreferences from 'eslint-plugin-markdown-preferences'
const GENERATED_IGNORES = [
'**/storybook-static/',
'**/.next/',
'web/next/',
'web/next-env.d.ts',
'**/dist/',
'**/coverage/',
'e2e/.auth/',
'e2e/cucumber-report/',
]
export default antfu(
{
ignores: original => [
@ -15,6 +26,7 @@ export default antfu(
'!package.json',
'!pnpm-workspace.yaml',
'!vite.config.ts',
...GENERATED_IGNORES,
...original,
],
typescript: {

View File

@ -2,7 +2,7 @@
"name": "dify",
"type": "module",
"private": true,
"packageManager": "pnpm@11.0.0",
"packageManager": "pnpm@11.0.6",
"engines": {
"node": "^22.22.1"
},

View File

@ -56,4 +56,28 @@ The Figma design system uses `--radius/*` tokens whose scale is **offset by one
- When the Figma MCP returns `rounded-[var(--radius/sm, 6px)]`, convert it to the standard Tailwind class from the table above (e.g. `rounded-md`).
- For values without a standard Tailwind equivalent (10px, 20px, 28px), use arbitrary values like `rounded-[10px]`.
## Search / Picker Primitive Selection: Autocomplete vs Combobox vs Select
Pick by whether the user is entering free-form text, choosing a remembered value, or selecting from a closed list.
Base UI decision rules:
- [Autocomplete docs]: use `Combobox` instead of `Autocomplete` if the selection should be remembered and the input value cannot be custom.
- [Combobox docs]: do not use `Combobox` for simple search widgets that require unrestricted text entry; use `Autocomplete` instead.
Apply this split in Dify UI:
- `Autocomplete` — free-form text input with optional suggestions or completions. The input value may be custom and does not necessarily become a selected option. Use for search boxes, command-style suggestions, tag suggestions, and async text completion.
- `Combobox` — searchable picker whose value is one or more selected items from a collection. The chosen value is remembered by the root, and free-form text is not the final value. Use for model pickers, user pickers, dataset/document pickers, and multi-select chips.
- `Select` — closed-list picker without text entry. Use when the option set is small or already scannable and filtering is unnecessary.
Composition rules:
- Keep Base UI primitive semantics visible in the public API. Export compound parts such as `ComboboxInputGroup`, `ComboboxInput`, `ComboboxContent`, `ComboboxList`, `ComboboxItem`, and `ComboboxItemIndicator` instead of wrapping them into one business component.
- For `Combobox` multiple selection, follow the official chips pattern: `ComboboxInputGroup` contains `ComboboxChips`, `ComboboxValue` renders `ComboboxChip` items, and `ComboboxInput` remains inside the chips row. Chips should wrap and let the input group grow vertically instead of forcing horizontal overflow.
- Content primitives must own their Base UI `Portal` and use `z-1002` on `Positioner`, matching the overlay contract in `README.md`.
- Use `w-(--anchor-width)` with viewport-aware max-width for `Autocomplete` and `Combobox` popups. Do not add `min-w-(--anchor-width)` when it would defeat available-width clamping.
[Autocomplete docs]: https://base-ui.com/react/components/autocomplete.md#usage-guidelines
[Combobox docs]: https://base-ui.com/react/components/combobox.md#usage-guidelines
[docs]: https://base-ui.com/react/components/tooltip#infotips

View File

@ -36,12 +36,12 @@ Importing from `@langgenius/dify-ui` (no subpath) is intentionally not supported
## Primitives
| Category | Subpath | Notes |
| -------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------- |
| Overlay | `./alert-dialog`, `./context-menu`, `./dialog`, `./dropdown-menu`, `./popover`, `./select`, `./toast`, `./tooltip` | Portalled. See [Overlay & portal contract] below. |
| Form | `./number-field`, `./slider`, `./switch` | Controlled / uncontrolled per Base UI defaults. |
| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. |
| Media | `./avatar`, `./button` | Button exposes `cva` variants. |
| Category | Subpath | Notes |
| -------- | -------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------- |
| Overlay | `./alert-dialog`, `./autocomplete`, `./combobox`, `./context-menu`, `./dialog`, `./dropdown-menu`, `./popover`, `./select`, `./toast`, `./tooltip` | Portalled. See [Overlay & portal contract] below. |
| Form | `./autocomplete`, `./combobox`, `./number-field`, `./slider`, `./switch` | Controlled / uncontrolled per Base UI defaults. |
| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. |
| Media | `./avatar`, `./button` | Button exposes `cva` variants. |
Utilities:
@ -65,7 +65,7 @@ If a consumer uses Dify UI source files through the workspace, add an explicit s
## Overlay & portal contract
All overlay primitives (`dialog`, `alert-dialog`, `popover`, `dropdown-menu`, `context-menu`, `select`, `tooltip`, `toast`) render their content inside a [Base UI Portal] attached to `document.body`. This is the Base UI default — see the upstream [Portals][Base UI Portal] docs for the underlying behavior. Consumers **do not** need to wrap anything in a portal manually.
All overlay primitives (`dialog`, `alert-dialog`, `autocomplete`, `combobox`, `popover`, `dropdown-menu`, `context-menu`, `select`, `tooltip`, `toast`) render their content inside a [Base UI Portal] attached to `document.body`. This is the Base UI default — see the upstream [Portals][Base UI Portal] docs for the underlying behavior. Consumers **do not** need to wrap anything in a portal manually.
### Root isolation requirement
@ -83,14 +83,14 @@ Equivalent: any root element with `isolation: isolate` in CSS. Without it, overl
Every overlay primitive uses a single, shared z-index. Do **not** override it at call sites.
| Layer | z-index | Where |
| ----------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- |
| Overlays (Dialog, AlertDialog, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop |
| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. |
| Layer | z-index | Where |
| ----------------------------------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- |
| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop |
| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. |
Rationale: during Dify's migration from legacy `portal-to-follow-elem` / `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins.
Rationale: during Dify's migration from legacy `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins.
See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history and the remaining legacy allowlist. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`.
See `[web/docs/overlay-migration.md](../../web/docs/overlay-migration.md)` for the Dify-web migration history. Once the legacy overlays are gone, the values in this table can drop back to `z-50` / `z-51`.
### Rules

View File

@ -13,6 +13,10 @@
"types": "./src/alert-dialog/index.tsx",
"import": "./src/alert-dialog/index.tsx"
},
"./autocomplete": {
"types": "./src/autocomplete/index.tsx",
"import": "./src/autocomplete/index.tsx"
},
"./avatar": {
"types": "./src/avatar/index.tsx",
"import": "./src/avatar/index.tsx"
@ -21,6 +25,10 @@
"types": "./src/button/index.tsx",
"import": "./src/button/index.tsx"
},
"./combobox": {
"types": "./src/combobox/index.tsx",
"import": "./src/combobox/index.tsx"
},
"./context-menu": {
"types": "./src/context-menu/index.tsx",
"import": "./src/context-menu/index.tsx"
@ -103,6 +111,7 @@
"@storybook/addon-themes": "catalog:",
"@storybook/react-vite": "catalog:",
"@tailwindcss/vite": "catalog:",
"@tanstack/react-virtual": "catalog:",
"@types/react": "catalog:",
"@types/react-dom": "catalog:",
"@typescript/native-preview": "catalog:",

View File

@ -0,0 +1,252 @@
import type { ReactNode } from 'react'
import { render } from 'vitest-browser-react'
import {
Autocomplete,
AutocompleteClear,
AutocompleteContent,
AutocompleteEmpty,
AutocompleteGroup,
AutocompleteInput,
AutocompleteInputGroup,
AutocompleteItem,
AutocompleteItemIndicator,
AutocompleteItemText,
AutocompleteLabel,
AutocompleteList,
AutocompleteSeparator,
AutocompleteStatus,
AutocompleteTrigger,
} from '../index'
const renderWithSafeViewport = (ui: ReactNode) => render(
<div style={{ minHeight: '100vh', minWidth: '100vw', padding: '240px' }}>
{ui}
</div>,
)
const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement
const renderAutocomplete = ({
children,
open = false,
defaultValue = 'workflow',
}: {
children?: ReactNode
open?: boolean
defaultValue?: string
} = {}) => renderWithSafeViewport(
<Autocomplete open={open} defaultValue={defaultValue} items={['workflow', 'dataset']}>
{children ?? (
<>
<AutocompleteInputGroup data-testid="input-group">
<AutocompleteInput aria-label="Search suggestions" data-testid="input" />
<AutocompleteClear data-testid="clear" />
<AutocompleteTrigger data-testid="trigger" />
</AutocompleteInputGroup>
<AutocompleteContent
positionerProps={{
'role': 'group',
'aria-label': 'autocomplete positioner',
}}
popupProps={{
'role': 'dialog',
'aria-label': 'autocomplete popup',
}}
>
<AutocompleteStatus data-testid="status">2 suggestions</AutocompleteStatus>
<AutocompleteList role="listbox" aria-label="autocomplete list" data-testid="list">
<AutocompleteItem value="workflow">
<AutocompleteItemText>Workflow</AutocompleteItemText>
<AutocompleteItemIndicator />
</AutocompleteItem>
<AutocompleteItem value="dataset">
<AutocompleteItemText>Dataset</AutocompleteItemText>
</AutocompleteItem>
</AutocompleteList>
<AutocompleteEmpty data-testid="empty">No suggestions</AutocompleteEmpty>
</AutocompleteContent>
</>
)}
</Autocomplete>,
)
describe('Autocomplete wrappers', () => {
describe('Input group and input', () => {
it('should apply medium input group and input classes by default', async () => {
const screen = await renderAutocomplete()
await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-lg')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('px-3')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('system-sm-regular')
})
it('should apply large input group and input classes when large size is provided', async () => {
const screen = await renderAutocomplete({
children: (
<AutocompleteInputGroup size="large" data-testid="input-group">
<AutocompleteInput size="large" aria-label="Search suggestions" data-testid="input" />
</AutocompleteInputGroup>
),
})
await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-[10px]')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('px-4')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('system-md-regular')
})
it('should set input defaults and forward passthrough props', async () => {
const screen = await renderAutocomplete({
children: (
<AutocompleteInputGroup>
<AutocompleteInput
aria-label="Search suggestions"
className="custom-input"
placeholder="Find a resource"
required
/>
</AutocompleteInputGroup>
),
})
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('autocomplete', 'off')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('type', 'text')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('placeholder', 'Find a resource')
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toBeRequired()
await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('custom-input')
})
})
describe('Controls', () => {
it('should provide fallback aria labels and decorative icons when labels are omitted', async () => {
const screen = await renderAutocomplete()
await expect.element(screen.getByRole('button', { name: 'Clear autocomplete' })).toHaveAttribute('type', 'button')
await expect.element(screen.getByRole('button', { name: 'Open autocomplete suggestions' })).toHaveAttribute('type', 'button')
expect(screen.getByRole('button', { name: 'Clear autocomplete' }).element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true')
expect(screen.getByRole('button', { name: 'Open autocomplete suggestions' }).element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true')
})
it('should preserve explicit labels and custom children', async () => {
const screen = await renderAutocomplete({
children: (
<AutocompleteInputGroup>
<AutocompleteInput aria-label="Search suggestions" />
<AutocompleteClear aria-label="Reset search">
<span data-testid="custom-clear">reset</span>
</AutocompleteClear>
<AutocompleteTrigger aria-label="Show suggestions">
<span data-testid="custom-trigger">open</span>
</AutocompleteTrigger>
</AutocompleteInputGroup>
),
})
expect(screen.getByRole('button', { name: 'Reset search' }).element()).toContainElement(screen.getByTestId('custom-clear').element())
expect(screen.getByRole('button', { name: 'Show suggestions' }).element()).toContainElement(screen.getByTestId('custom-trigger').element())
expect(screen.getByRole('button', { name: 'Reset search' }).element().querySelector('.i-ri-close-line')).not.toBeInTheDocument()
expect(screen.getByRole('button', { name: 'Show suggestions' }).element().querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument()
})
it('should rely on aria-labelledby when provided instead of injecting fallback labels', async () => {
const screen = await renderAutocomplete({
children: (
<>
<span id="clear-label">Clear from label</span>
<span id="trigger-label">Trigger from label</span>
<AutocompleteInputGroup>
<AutocompleteInput aria-label="Search suggestions" />
<AutocompleteClear aria-labelledby="clear-label" />
<AutocompleteTrigger aria-labelledby="trigger-label" />
</AutocompleteInputGroup>
</>
),
})
await expect.element(screen.getByRole('button', { name: 'Clear from label' })).not.toHaveAttribute('aria-label')
await expect.element(screen.getByRole('button', { name: 'Trigger from label' })).not.toHaveAttribute('aria-label')
})
})
describe('Content and options', () => {
it('should use default overlay placement and Dify popup classes', async () => {
const screen = await renderAutocomplete({ open: true })
await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'bottom')
await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-align', 'start')
await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-1002')
await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('rounded-xl')
await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('w-(--anchor-width)')
await expect.element(screen.getByRole('listbox', { name: 'autocomplete list' })).toHaveClass('scroll-py-1')
})
it('should apply custom placement side and passthrough popup props', async () => {
const onPopupClick = vi.fn()
const screen = await renderWithSafeViewport(
<Autocomplete open defaultValue="workflow" items={['workflow']}>
<AutocompleteInputGroup>
<AutocompleteInput aria-label="Search suggestions" />
</AutocompleteInputGroup>
<AutocompleteContent
placement="top-end"
sideOffset={12}
alignOffset={6}
positionerProps={{ 'role': 'group', 'aria-label': 'autocomplete positioner' }}
popupProps={{
'role': 'dialog',
'aria-label': 'autocomplete popup',
'onClick': onPopupClick,
}}
>
<AutocompleteList role="listbox" aria-label="autocomplete list">
<AutocompleteItem value="workflow">
<AutocompleteItemText>Workflow</AutocompleteItemText>
</AutocompleteItem>
</AutocompleteList>
</AutocompleteContent>
</Autocomplete>,
)
asHTMLElement(screen.getByRole('dialog', { name: 'autocomplete popup' }).element()).click()
await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'top')
expect(onPopupClick).toHaveBeenCalledTimes(1)
})
it('should render item text indicator status and empty wrappers with design classes', async () => {
const screen = await renderAutocomplete({ open: true })
await expect.element(screen.getByText('Workflow')).toHaveClass('system-sm-medium')
await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary')
await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular')
expect(screen.getByText('Workflow').element().parentElement?.querySelector('.i-ri-arrow-right-line')).toHaveAttribute('aria-hidden', 'true')
})
it('should forward custom classes to label separator item text and indicator', async () => {
const screen = await renderWithSafeViewport(
<Autocomplete open defaultValue="workflow" items={['workflow']}>
<AutocompleteInputGroup>
<AutocompleteInput aria-label="Search suggestions" />
</AutocompleteInputGroup>
<AutocompleteContent popupProps={{ 'role': 'dialog', 'aria-label': 'autocomplete popup' }}>
<AutocompleteList role="listbox" aria-label="autocomplete list">
<AutocompleteGroup items={['workflow']}>
<AutocompleteLabel className="custom-label">Resources</AutocompleteLabel>
<AutocompleteSeparator className="custom-separator" data-testid="separator" />
<AutocompleteItem value="workflow" className="custom-item">
<AutocompleteItemText className="custom-text">Workflow</AutocompleteItemText>
<AutocompleteItemIndicator className="custom-indicator" data-testid="indicator" />
</AutocompleteItem>
</AutocompleteGroup>
</AutocompleteList>
</AutocompleteContent>
</Autocomplete>,
)
await expect.element(screen.getByText('Resources')).toHaveClass('custom-label')
await expect.element(screen.getByTestId('separator')).toHaveClass('custom-separator')
await expect.element(screen.getByRole('option', { name: 'Workflow' })).toHaveClass('custom-item')
await expect.element(screen.getByText('Workflow')).toHaveClass('custom-text')
await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator')
})
})
})

View File

@ -0,0 +1,721 @@
import type { Meta, StoryObj } from '@storybook/react-vite'
import type { Virtualizer } from '@tanstack/react-virtual'
import type { RefObject } from 'react'
import { useVirtualizer } from '@tanstack/react-virtual'
import { useEffect, useMemo, useRef, useState } from 'react'
import {
Autocomplete,
AutocompleteClear,
AutocompleteCollection,
AutocompleteContent,
AutocompleteEmpty,
AutocompleteGroup,
AutocompleteInput,
AutocompleteInputGroup,
AutocompleteItem,
AutocompleteItemText,
AutocompleteLabel,
AutocompleteList,
AutocompleteSeparator,
AutocompleteStatus,
AutocompleteTrigger,
useAutocompleteFilter,
useAutocompleteFilteredItems,
} from '.'
import { cn } from '../cn'
type Suggestion = {
value: string
label: string
description?: string
icon?: string
meta?: string
}
type SuggestionGroup = {
label: string
items: Suggestion[]
}
const inputWidth = 'w-80'
type StoryVirtualizer = Virtualizer<HTMLDivElement, Element>
const scrollHighlightedVirtualItem = (
item: unknown,
{
reason,
index,
}: {
reason: 'keyboard' | 'pointer' | 'none'
index: number
},
virtualizer: StoryVirtualizer | null,
) => {
if (!item || !virtualizer)
return
const isStart = index === 0
const isEnd = index === virtualizer.options.count - 1
const shouldScroll = reason === 'none' || (reason === 'keyboard' && (isStart || isEnd))
if (shouldScroll) {
queueMicrotask(() => {
virtualizer.scrollToIndex(index, { align: isEnd ? 'start' : 'end' })
})
}
}
const tagSuggestions: Suggestion[] = [
{ value: 'feature', label: 'feature', description: 'Product work and launch notes' },
{ value: 'fix', label: 'fix', description: 'Bug fixes and regressions' },
{ value: 'docs', label: 'docs', description: 'Documentation updates' },
{ value: 'internal', label: 'internal', description: 'Workspace-only notes' },
{ value: 'mobile', label: 'mobile', description: 'Mobile app issues' },
{ value: 'component: autocomplete', label: 'component: autocomplete', description: 'Base UI primitive wrapper' },
{ value: 'component: combobox', label: 'component: combobox', description: 'Filterable predefined selection' },
{ value: 'component: select', label: 'component: select', description: 'Compact predefined selection' },
]
const promptCompletions: Suggestion[] = [
{ value: 'summarize this conversation', label: 'summarize this conversation' },
{ value: 'summarize this dataset with citations', label: 'summarize this dataset with citations' },
{ value: 'summarize this workflow run for an operator', label: 'summarize this workflow run for an operator' },
{ value: 'summarize this support ticket in 3 bullets', label: 'summarize this support ticket in 3 bullets' },
]
const workflowSuggestions: Suggestion[] = [
{ value: 'http-request', label: 'HTTP Request', description: 'Call an external API', icon: 'i-ri-global-line', meta: 'Tool' },
{ value: 'knowledge-retrieval', label: 'Knowledge Retrieval', description: 'Search configured datasets', icon: 'i-ri-database-2-line', meta: 'Tool' },
{ value: 'code-execution', label: 'Code Execution', description: 'Run sandboxed snippets', icon: 'i-ri-code-s-slash-line', meta: 'Tool' },
{ value: 'template-transform', label: 'Template Transform', description: 'Compose variables into output', icon: 'i-ri-braces-line', meta: 'Tool' },
{ value: 'question-classifier', label: 'Question Classifier', description: 'Route by intent', icon: 'i-ri-git-branch-line', meta: 'Tool' },
{ value: 'parameter-extractor', label: 'Parameter Extractor', description: 'Extract typed values', icon: 'i-ri-list-check-3', meta: 'Tool' },
{ value: 'answer-node', label: 'Answer Node', description: 'Return a final assistant answer', icon: 'i-ri-message-3-line', meta: 'Node' },
{ value: 'iteration-node', label: 'Iteration Node', description: 'Run a loop over array items', icon: 'i-ri-repeat-line', meta: 'Node' },
{ value: 'variable-assigner', label: 'Variable Assigner', description: 'Persist intermediate state', icon: 'i-ri-pencil-ruler-2-line', meta: 'Node' },
]
const groupedSuggestions: SuggestionGroup[] = [
{
label: 'Tags',
items: tagSuggestions.slice(0, 5),
},
{
label: 'Workflow Suggestions',
items: workflowSuggestions.slice(0, 5),
},
{
label: 'Prompt Starters',
items: promptCompletions.slice(0, 3),
},
]
const commandGroups: SuggestionGroup[] = [
{
label: 'App',
items: [
{ value: '/run', label: 'Run workflow', description: 'Execute the current draft', icon: 'i-ri-play-circle-line' },
{ value: '/publish', label: 'Publish app', description: 'Ship the current configuration', icon: 'i-ri-upload-cloud-2-line' },
{ value: '/trace', label: 'Open trace', description: 'Inspect the latest workflow run', icon: 'i-ri-route-line' },
],
},
{
label: 'Workspace',
items: [
{ value: '/dataset', label: 'Search datasets', description: 'Find knowledge attached to this app', icon: 'i-ri-database-line' },
{ value: '/members', label: 'Invite members', description: 'Open workspace access settings', icon: 'i-ri-user-add-line' },
{ value: '/usage', label: 'View usage', description: 'Open model and workflow usage', icon: 'i-ri-bar-chart-line' },
],
},
]
const remoteSuggestions: Suggestion[] = [
{ value: 'agent-builder', label: 'Agent Builder', description: 'Workspace app' },
{ value: 'agent-observability', label: 'Agent Observability', description: 'Dataset' },
{ value: 'agent-routing-dataset', label: 'Agent Routing Dataset', description: 'Knowledge source' },
]
const virtualizedSuggestions: Suggestion[] = Array.from({ length: 1000 }, (_, index) => {
const family = ['workflow', 'dataset', 'prompt', 'tool'][index % 4]!
const number = new Intl.NumberFormat('en-US', {
minimumIntegerDigits: 4,
}).format(index + 1)
return {
value: `${family}-${index + 1}`,
label: `${family} suggestion ${number}`,
description: `Free-form autocomplete result from ${family} search`,
icon: family === 'dataset'
? 'i-ri-database-2-line'
: family === 'prompt'
? 'i-ri-text-snippet'
: family === 'tool'
? 'i-ri-tools-line'
: 'i-ri-flow-chart',
meta: family,
}
})
const getSuggestionLabel = (item: Suggestion) => item.label
const SuggestionItem = ({
item,
index,
dense,
}: {
item: Suggestion
index?: number
dense?: boolean
}) => (
<AutocompleteItem value={item} index={index}>
{item.icon && <span className={cn(item.icon, 'size-4 shrink-0 text-text-tertiary')} aria-hidden="true" />}
<div className="flex min-w-0 grow flex-col">
<AutocompleteItemText className="px-0">{item.label}</AutocompleteItemText>
{!dense && item.description && (
<span className="truncate system-xs-regular text-text-tertiary">{item.description}</span>
)}
</div>
{item.meta && (
<span className="shrink-0 rounded-md bg-components-badge-bg-dimm px-1.5 py-0.5 system-2xs-medium text-text-tertiary">
{item.meta}
</span>
)}
</AutocompleteItem>
)
const TagSuggestionItem = ({
item,
index,
}: {
item: Suggestion
index?: number
}) => (
<AutocompleteItem value={item} index={index}>
<AutocompleteItemText className="px-0">{item.label}</AutocompleteItemText>
{item.description && <span className="ml-auto max-w-36 truncate system-xs-regular text-text-tertiary">{item.description}</span>}
</AutocompleteItem>
)
const BasicTagAutocomplete = ({
size = 'medium',
}: {
size?: 'small' | 'medium' | 'large'
}) => (
<Autocomplete
items={tagSuggestions}
itemToStringValue={getSuggestionLabel}
openOnInputClick
>
<AutocompleteInputGroup size={size}>
<span className="i-ri-search-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput size={size} placeholder="Search tags or type a new one…" aria-label="Search tags or type a new one" />
<AutocompleteClear size={size} />
<AutocompleteTrigger size={size} />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<TagSuggestionItem key={item.value} item={item} index={index} />
)}
</AutocompleteList>
<AutocompleteEmpty>No tag suggestion. Keep the typed value.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
)
const GroupedSuggestionList = () => {
const groups = useAutocompleteFilteredItems<SuggestionGroup>()
return (
<AutocompleteList>
{groups.map((group, groupIndex) => (
<AutocompleteGroup key={group.label} items={group.items}>
{groupIndex > 0 && <AutocompleteSeparator />}
<AutocompleteLabel>{group.label}</AutocompleteLabel>
<AutocompleteCollection>
{(item: Suggestion) => (
<SuggestionItem key={item.value} item={item} />
)}
</AutocompleteCollection>
</AutocompleteGroup>
))}
</AutocompleteList>
)
}
const CommandPaletteList = () => {
const groups = useAutocompleteFilteredItems<SuggestionGroup>()
return (
<AutocompleteList className="max-h-72 rounded-lg border border-divider-subtle bg-components-panel-bg p-1 shadow-xs">
{groups.map((group, groupIndex) => (
<AutocompleteGroup key={group.label} items={group.items}>
{groupIndex > 0 && <AutocompleteSeparator />}
<AutocompleteLabel>{group.label}</AutocompleteLabel>
<AutocompleteCollection>
{(item: Suggestion) => (
<AutocompleteItem key={item.value} value={item} className="grid grid-cols-[1fr_auto]">
<span className="flex min-w-0 items-center gap-2">
{item.icon && <span className={cn(item.icon, 'size-4 shrink-0 text-text-tertiary')} aria-hidden="true" />}
<span className="min-w-0">
<AutocompleteItemText className="block px-0">{item.label}</AutocompleteItemText>
<span className="block truncate system-xs-regular text-text-tertiary">{item.description}</span>
</span>
</span>
<kbd className="rounded-md border border-divider-subtle bg-components-badge-bg-dimm px-1.5 py-0.5 text-text-quaternary system-2xs-medium">
Enter
</kbd>
</AutocompleteItem>
)}
</AutocompleteCollection>
</AutocompleteGroup>
))}
</AutocompleteList>
)
}
const LimitedStatus = ({
total,
}: {
total: number
}) => {
const items = useAutocompleteFilteredItems<Suggestion>()
const hidden = Math.max(0, total - items.length)
return hidden > 0
? `${hidden} more suggestions hidden. Refine the query to narrow results.`
: `${items.length} suggestions available.`
}
const AsyncSearchDemo = () => {
const [value, setValue] = useState('agent')
const [loading, setLoading] = useState(false)
const [items, setItems] = useState(remoteSuggestions)
useEffect(() => {
setLoading(true)
const timeout = window.setTimeout(() => {
setItems(
value.trim()
? remoteSuggestions.filter(item => item.label.toLowerCase().includes(value.trim().toLowerCase()))
: remoteSuggestions,
)
setLoading(false)
}, 500)
return () => window.clearTimeout(timeout)
}, [value])
return (
<div className={inputWidth}>
<Autocomplete
items={items}
value={value}
onValueChange={setValue}
itemToStringValue={getSuggestionLabel}
openOnInputClick
>
<AutocompleteInputGroup>
<span className="i-ri-cloud-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Search remote resources…" aria-label="Search remote resources" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteStatus>
{loading ? 'Loading suggestions…' : `${items.length} remote suggestions`}
</AutocompleteStatus>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<SuggestionItem key={item.value} item={item} index={index} />
)}
</AutocompleteList>
<AutocompleteEmpty>No remote suggestion. Keep the typed query.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
)
}
const VirtualizedSuggestionList = ({
virtualizerRef,
}: {
virtualizerRef: RefObject<StoryVirtualizer | null>
}) => {
const scrollRef = useRef<HTMLDivElement | null>(null)
const filteredItems = useAutocompleteFilteredItems<Suggestion>()
const virtualizer = useVirtualizer({
count: filteredItems.length,
getScrollElement: () => scrollRef.current,
estimateSize: () => 44,
overscan: 6,
})
useEffect(() => {
virtualizerRef.current = virtualizer
return () => {
virtualizerRef.current = null
}
}, [virtualizer, virtualizerRef])
return (
<div
ref={scrollRef}
className="max-h-[min(22rem,var(--available-height))] overflow-y-auto overflow-x-hidden overscroll-contain outline-hidden"
>
<AutocompleteList
className="relative max-h-none overflow-visible p-0"
style={{ height: virtualizer.getTotalSize() }}
>
{virtualizer.getVirtualItems().map((virtualItem) => {
const item = filteredItems[virtualItem.index]
if (!item)
return null
return (
<div
key={virtualItem.key}
className="absolute top-0 left-0 w-full"
style={{
height: virtualItem.size,
transform: `translateY(${virtualItem.start}px)`,
}}
>
<SuggestionItem item={item} index={virtualItem.index} />
</div>
)
})}
</AutocompleteList>
</div>
)
}
const VirtualizedStatus = () => {
const filteredItems = useAutocompleteFilteredItems<Suggestion>()
return (
<AutocompleteStatus className="border-b border-divider-subtle text-text-quaternary tabular-nums">
{filteredItems.length}
{' '}
matching suggestions. Selecting one only replaces the input text.
</AutocompleteStatus>
)
}
const FuzzyHighlight = ({
text,
query,
}: {
text: string
query: string
}) => {
const parts = useMemo(() => {
const trimmed = query.trim()
if (!trimmed)
return [text]
const escaped = trimmed.slice(0, 80).replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
return text.split(new RegExp(`(${escaped})`, 'i'))
}, [query, text])
return (
<>
{parts.map((part, index) => (
part.toLowerCase() === query.trim().toLowerCase()
? <mark key={`${part}-${index}`} className="bg-transparent text-text-accent">{part}</mark>
: part
))}
</>
)
}
const FuzzyMatchingDemo = () => {
const [value, setValue] = useState('retr')
const { contains } = useAutocompleteFilter({ sensitivity: 'base' })
return (
<div className={inputWidth}>
<Autocomplete
items={workflowSuggestions}
value={value}
onValueChange={setValue}
filter={contains}
itemToStringValue={getSuggestionLabel}
openOnInputClick
>
<AutocompleteInputGroup>
<span className="i-ri-sparkling-2-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Fuzzy search workflow suggestions…" aria-label="Fuzzy search workflow suggestions" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<AutocompleteItem key={item.value} value={item} index={index}>
{item.icon && <span className={cn(item.icon, 'size-4 shrink-0 text-text-tertiary')} aria-hidden="true" />}
<div className="min-w-0 grow">
<AutocompleteItemText className="block px-0">
<FuzzyHighlight text={item.label} query={value} />
</AutocompleteItemText>
<span className="block truncate system-xs-regular text-text-tertiary">{item.description}</span>
</div>
</AutocompleteItem>
)}
</AutocompleteList>
<AutocompleteEmpty>No workflow suggestion. Keep typing freely.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
)
}
const meta = {
title: 'Base/UI/Autocomplete',
component: Autocomplete,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Compound autocomplete built on Base UI Autocomplete. Use it for free-form inputs where suggestions can replace or complete the typed text, but selection is not persistent state.',
},
},
},
tags: ['autodocs'],
} satisfies Meta<typeof Autocomplete>
export default meta
type Story = StoryObj<typeof meta>
export const SearchTags: Story = {
render: () => (
<div className={inputWidth}>
<BasicTagAutocomplete />
</div>
),
}
export const Sizes: Story = {
render: () => (
<div className="flex flex-col gap-3">
{(['small', 'medium', 'large'] as const).map(size => (
<div key={size} className={inputWidth}>
<BasicTagAutocomplete size={size} />
</div>
))}
</div>
),
}
export const InlineAutocomplete: Story = {
render: () => (
<div className={inputWidth}>
<Autocomplete
items={promptCompletions}
itemToStringValue={getSuggestionLabel}
mode="both"
openOnInputClick
>
<AutocompleteInputGroup>
<span className="i-ri-text-snippet ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Type a prompt starter…" aria-label="Type a prompt starter" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<SuggestionItem key={item.value} item={item} index={index} dense />
)}
</AutocompleteList>
<AutocompleteEmpty>No inline completion. Continue typing freely.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
),
}
export const GroupedSuggestions: Story = {
render: () => (
<div className={inputWidth}>
<Autocomplete
items={groupedSuggestions}
itemToStringValue={getSuggestionLabel}
openOnInputClick
>
<AutocompleteInputGroup>
<span className="i-ri-command-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Search tags, nodes, or prompt starters…" aria-label="Search tags, nodes, or prompt starters" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent popupClassName="w-[420px]">
<GroupedSuggestionList />
<AutocompleteEmpty>No suggestion. Use the text as entered.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
),
}
export const FuzzyMatching: Story = {
render: () => <FuzzyMatchingDemo />,
}
export const LimitResults: Story = {
render: () => (
<div className={inputWidth}>
<Autocomplete
items={workflowSuggestions}
itemToStringValue={getSuggestionLabel}
limit={5}
openOnInputClick
>
<AutocompleteInputGroup>
<span className="i-ri-tools-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Search workflow suggestions…" aria-label="Search workflow suggestions" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent popupClassName="w-[420px]">
<AutocompleteStatus className="border-b border-divider-subtle">
<LimitedStatus total={workflowSuggestions.length} />
</AutocompleteStatus>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<SuggestionItem key={item.value} item={item} index={index} />
)}
</AutocompleteList>
<AutocompleteEmpty>No suggestion. Submit the typed text instead.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
),
}
export const CommandPalette: Story = {
render: () => (
<div className="w-[440px] rounded-xl border border-divider-subtle bg-components-panel-bg-alt p-2 shadow-xs">
<Autocomplete
open
inline
items={commandGroups}
itemToStringValue={getSuggestionLabel}
autoHighlight="always"
keepHighlight
>
<AutocompleteInputGroup className="mb-2">
<span className="i-ri-search-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Run a command…" aria-label="Run a command" />
<AutocompleteClear />
</AutocompleteInputGroup>
<CommandPaletteList />
</Autocomplete>
</div>
),
}
const VirtualizedLongSuggestionsDemo = () => {
const virtualizerRef = useRef<StoryVirtualizer | null>(null)
return (
<div className={inputWidth}>
<Autocomplete
items={virtualizedSuggestions}
itemToStringValue={getSuggestionLabel}
virtualized
openOnInputClick
onItemHighlighted={(item, details) => {
scrollHighlightedVirtualItem(item, details, virtualizerRef.current)
}}
>
<AutocompleteInputGroup>
<span className="i-ri-search-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Search 1,000 workspace suggestions…" aria-label="Search 1,000 workspace suggestions" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent popupClassName="w-[440px] p-1">
<VirtualizedStatus />
<VirtualizedSuggestionList virtualizerRef={virtualizerRef} />
<AutocompleteEmpty>No suggestion. Free-form text is still valid.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
)
}
export const VirtualizedLongSuggestions: Story = {
render: () => <VirtualizedLongSuggestionsDemo />,
}
export const AsyncSearch: Story = {
render: () => <AsyncSearchDemo />,
}
export const Empty: Story = {
render: () => (
<div className={inputWidth}>
<Autocomplete
items={tagSuggestions}
itemToStringValue={getSuggestionLabel}
defaultValue="private-release-note"
openOnInputClick
>
<AutocompleteInputGroup>
<span className="i-ri-search-line ml-2 size-4 shrink-0 text-text-tertiary" aria-hidden="true" />
<AutocompleteInput placeholder="Search tags or type a new one…" aria-label="Search tags or type a new one" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<TagSuggestionItem key={item.value} item={item} index={index} />
)}
</AutocompleteList>
<AutocompleteEmpty>No tag suggestion. The custom text remains valid.</AutocompleteEmpty>
</AutocompleteContent>
</Autocomplete>
</div>
),
}
export const DisabledAndReadOnly: Story = {
render: () => (
<div className="flex w-80 flex-col gap-3">
<Autocomplete items={tagSuggestions} itemToStringValue={getSuggestionLabel} defaultValue="feature" disabled>
<AutocompleteInputGroup>
<AutocompleteInput aria-label="Disabled tag autocomplete" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<TagSuggestionItem key={item.value} item={item} index={index} />
)}
</AutocompleteList>
</AutocompleteContent>
</Autocomplete>
<Autocomplete items={promptCompletions} itemToStringValue={getSuggestionLabel} defaultValue="summarize this conversation" readOnly>
<AutocompleteInputGroup>
<AutocompleteInput aria-label="Read-only prompt autocomplete" />
<AutocompleteClear />
<AutocompleteTrigger />
</AutocompleteInputGroup>
<AutocompleteContent>
<AutocompleteList>
{(item: Suggestion, index: number) => (
<SuggestionItem key={item.value} item={item} index={index} />
)}
</AutocompleteList>
</AutocompleteContent>
</Autocomplete>
</div>
),
}

View File

@ -0,0 +1,381 @@
'use client'
import type { VariantProps } from 'class-variance-authority'
import type { HTMLAttributes, ReactNode } from 'react'
import type { Placement } from '../placement'
import { Autocomplete as BaseAutocomplete } from '@base-ui/react/autocomplete'
import { cva } from 'class-variance-authority'
import { cn } from '../cn'
import {
overlayIndicatorClassName,
overlayLabelClassName,
overlayPopupAnimationClassName,
overlaySeparatorClassName,
} from '../overlay-shared'
import { parsePlacement } from '../placement'
export type { Placement }
export const Autocomplete = BaseAutocomplete.Root
export const AutocompleteValue = BaseAutocomplete.Value
export const AutocompleteGroup = BaseAutocomplete.Group
export const AutocompleteCollection = BaseAutocomplete.Collection
export const AutocompleteRow = BaseAutocomplete.Row
export const useAutocompleteFilter = BaseAutocomplete.useFilter
export const useAutocompleteFilteredItems = BaseAutocomplete.useFilteredItems
export type AutocompleteRootProps<ItemValue> = BaseAutocomplete.Root.Props<ItemValue>
export type AutocompleteRootChangeEventDetails = BaseAutocomplete.Root.ChangeEventDetails
export type AutocompleteRootHighlightEventDetails = BaseAutocomplete.Root.HighlightEventDetails
const autocompletePopupClassName = [
'w-(--anchor-width) max-w-[min(28rem,var(--available-width))] overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg outline-hidden',
'data-side-top:origin-bottom data-side-bottom:origin-top data-side-left:origin-right data-side-right:origin-left',
]
const autocompleteListClassName = [
'max-h-[min(20rem,var(--available-height))] overflow-y-auto overflow-x-hidden overscroll-contain p-1 outline-hidden scroll-py-1',
'data-empty:max-h-none data-empty:p-0',
]
const autocompleteItemClassName = [
'mx-1 flex min-h-8 cursor-pointer select-none items-center gap-2 rounded-lg px-2 py-1.5 text-text-secondary outline-hidden transition-colors',
'hover:bg-state-base-hover-alt hover:text-text-primary',
'data-highlighted:bg-state-base-hover data-highlighted:text-text-primary',
'data-disabled:cursor-not-allowed data-disabled:opacity-30 data-disabled:hover:bg-transparent data-disabled:hover:text-text-secondary',
'motion-reduce:transition-none',
]
const autocompleteInputGroupVariants = cva(
[
'group/autocomplete flex w-full min-w-0 items-center border border-transparent bg-components-input-bg-normal text-components-input-text-filled shadow-none outline-hidden transition-[background-color,border-color,box-shadow]',
'hover:border-components-input-border-hover hover:bg-components-input-bg-hover',
'focus-within:border-components-input-border-active focus-within:bg-components-input-bg-active focus-within:shadow-xs',
'data-focused:border-components-input-border-active data-focused:bg-components-input-bg-active data-focused:shadow-xs',
'data-disabled:cursor-not-allowed data-disabled:border-transparent data-disabled:bg-components-input-bg-disabled data-disabled:text-components-input-text-filled-disabled',
'data-disabled:hover:border-transparent data-disabled:hover:bg-components-input-bg-disabled',
'data-readonly:shadow-none data-readonly:hover:border-transparent data-readonly:hover:bg-components-input-bg-normal',
'motion-reduce:transition-none',
],
{
variants: {
size: {
small: 'h-6 rounded-md',
medium: 'h-8 rounded-lg',
large: 'h-9 rounded-[10px]',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type AutocompleteSize = NonNullable<VariantProps<typeof autocompleteInputGroupVariants>['size']>
export type AutocompleteInputGroupProps
= BaseAutocomplete.InputGroup.Props
& VariantProps<typeof autocompleteInputGroupVariants>
export function AutocompleteInputGroup({
className,
size = 'medium',
...props
}: AutocompleteInputGroupProps) {
return (
<BaseAutocomplete.InputGroup
className={cn(autocompleteInputGroupVariants({ size }), className)}
{...props}
/>
)
}
const autocompleteInputVariants = cva(
[
'w-0 min-w-0 flex-1 appearance-none border-0 bg-transparent text-components-input-text-filled caret-primary-600 outline-hidden',
'placeholder:text-components-input-text-placeholder',
'disabled:cursor-not-allowed disabled:text-components-input-text-filled-disabled disabled:placeholder:text-components-input-text-disabled',
'data-readonly:cursor-default',
],
{
variants: {
size: {
small: 'px-2 py-1 system-xs-regular',
medium: 'px-3 py-[7px] system-sm-regular',
large: 'px-4 py-2 system-md-regular',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type AutocompleteInputProps
= Omit<BaseAutocomplete.Input.Props, 'size'>
& VariantProps<typeof autocompleteInputVariants>
export function AutocompleteInput({
className,
size = 'medium',
type = 'text',
autoComplete = 'off',
...props
}: AutocompleteInputProps) {
return (
<BaseAutocomplete.Input
type={type}
autoComplete={autoComplete}
className={cn(autocompleteInputVariants({ size }), className)}
{...props}
/>
)
}
const autocompleteControlVariants = cva(
[
'flex shrink-0 touch-manipulation items-center justify-center rounded-md text-text-tertiary outline-hidden transition-colors',
'hover:bg-components-input-bg-hover hover:text-text-secondary focus-visible:bg-components-input-bg-hover focus-visible:text-text-secondary',
'focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:ring-inset',
'disabled:cursor-not-allowed disabled:hover:bg-transparent disabled:hover:text-text-tertiary disabled:focus-visible:bg-transparent disabled:focus-visible:ring-0',
'group-data-disabled/autocomplete:cursor-not-allowed group-data-disabled/autocomplete:hover:bg-transparent group-data-disabled/autocomplete:focus-visible:bg-transparent group-data-disabled/autocomplete:focus-visible:ring-0',
'group-data-readonly/autocomplete:hidden',
'motion-reduce:transition-none',
],
{
variants: {
size: {
small: 'mr-1 size-4',
medium: 'mr-1.5 size-5',
large: 'mr-2 size-5',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type AutocompleteControlProps
= Omit<BaseAutocomplete.Trigger.Props, 'className'>
& VariantProps<typeof autocompleteControlVariants>
& { className?: string }
export function AutocompleteTrigger({
className,
children,
size = 'medium',
type = 'button',
...props
}: AutocompleteControlProps) {
return (
<BaseAutocomplete.Trigger
type={type}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : 'Open autocomplete suggestions')}
className={cn(autocompleteControlVariants({ size }), className)}
{...props}
>
{children ?? <span className="i-ri-arrow-down-s-line size-4" aria-hidden="true" />}
</BaseAutocomplete.Trigger>
)
}
export type AutocompleteClearProps
= Omit<BaseAutocomplete.Clear.Props, 'className'>
& VariantProps<typeof autocompleteControlVariants>
& { className?: string }
export function AutocompleteClear({
className,
children,
size = 'medium',
type = 'button',
...props
}: AutocompleteClearProps) {
return (
<BaseAutocomplete.Clear
type={type}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : 'Clear autocomplete')}
className={cn(
autocompleteControlVariants({ size }),
'data-ending-style:opacity-0 data-starting-style:opacity-0',
className,
)}
{...props}
>
{children ?? <span className="i-ri-close-line size-4" aria-hidden="true" />}
</BaseAutocomplete.Clear>
)
}
export function AutocompleteIcon({
className,
children,
...props
}: BaseAutocomplete.Icon.Props) {
return (
<BaseAutocomplete.Icon
className={cn('flex shrink-0 items-center text-text-tertiary', className)}
{...props}
>
{children ?? <span className="i-ri-arrow-down-s-line size-4" aria-hidden="true" />}
</BaseAutocomplete.Icon>
)
}
type AutocompleteContentProps = {
children: ReactNode
placement?: Placement
sideOffset?: number
alignOffset?: number
className?: string
popupClassName?: string
portalProps?: Omit<BaseAutocomplete.Portal.Props, 'children'>
positionerProps?: Omit<
BaseAutocomplete.Positioner.Props,
'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset'
>
popupProps?: Omit<
BaseAutocomplete.Popup.Props,
'children' | 'className'
>
}
export function AutocompleteContent({
children,
placement = 'bottom-start',
sideOffset = 4,
alignOffset = 0,
className,
popupClassName,
portalProps,
positionerProps,
popupProps,
}: AutocompleteContentProps) {
const { side, align } = parsePlacement(placement)
return (
<BaseAutocomplete.Portal {...portalProps}>
<BaseAutocomplete.Positioner
side={side}
align={align}
sideOffset={sideOffset}
alignOffset={alignOffset}
className={cn('z-1002 outline-hidden', className)}
{...positionerProps}
>
<BaseAutocomplete.Popup
className={cn(
autocompletePopupClassName,
overlayPopupAnimationClassName,
popupClassName,
)}
{...popupProps}
>
{children}
</BaseAutocomplete.Popup>
</BaseAutocomplete.Positioner>
</BaseAutocomplete.Portal>
)
}
export function AutocompleteList({
className,
...props
}: BaseAutocomplete.List.Props) {
return (
<BaseAutocomplete.List
className={cn(autocompleteListClassName, className)}
{...props}
/>
)
}
export function AutocompleteItem({
className,
...props
}: BaseAutocomplete.Item.Props) {
return (
<BaseAutocomplete.Item
className={cn(autocompleteItemClassName, className)}
{...props}
/>
)
}
export type AutocompleteItemTextProps = HTMLAttributes<HTMLSpanElement>
export function AutocompleteItemText({
className,
...props
}: AutocompleteItemTextProps) {
return (
<span
className={cn('min-w-0 grow truncate px-1 system-sm-medium', className)}
{...props}
/>
)
}
export function AutocompleteLabel({
className,
...props
}: BaseAutocomplete.GroupLabel.Props) {
return (
<BaseAutocomplete.GroupLabel
className={cn(overlayLabelClassName, className)}
{...props}
/>
)
}
export function AutocompleteSeparator({
className,
...props
}: BaseAutocomplete.Separator.Props) {
return (
<BaseAutocomplete.Separator
className={cn(overlaySeparatorClassName, className)}
{...props}
/>
)
}
export function AutocompleteEmpty({
className,
...props
}: BaseAutocomplete.Empty.Props) {
return (
<BaseAutocomplete.Empty
className={cn('px-3 py-2 system-sm-regular text-text-tertiary', className)}
{...props}
/>
)
}
export function AutocompleteStatus({
className,
...props
}: BaseAutocomplete.Status.Props) {
return (
<BaseAutocomplete.Status
className={cn('px-3 py-2 system-sm-regular text-text-tertiary', className)}
{...props}
/>
)
}
export function AutocompleteItemIndicator({
className,
children,
...props
}: HTMLAttributes<HTMLSpanElement>) {
return (
<span
className={cn(overlayIndicatorClassName, className)}
{...props}
>
{children ?? <span className="i-ri-arrow-right-line size-4" aria-hidden="true" />}
</span>
)
}

View File

@ -0,0 +1,363 @@
import type { ReactNode } from 'react'
import { render } from 'vitest-browser-react'
import {
Combobox,
ComboboxChip,
ComboboxChipRemove,
ComboboxChips,
ComboboxClear,
ComboboxContent,
ComboboxEmpty,
ComboboxGroup,
ComboboxGroupLabel,
ComboboxInput,
ComboboxInputGroup,
ComboboxInputTrigger,
ComboboxItem,
ComboboxItemIndicator,
ComboboxItemText,
ComboboxLabel,
ComboboxList,
ComboboxSeparator,
ComboboxStatus,
ComboboxTrigger,
ComboboxValue,
} from '../index'
const renderWithSafeViewport = (ui: ReactNode) => render(
<div style={{ minHeight: '100vh', minWidth: '100vw', padding: '240px' }}>
{ui}
</div>,
)
const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement
const renderSelectLikeCombobox = ({
children,
open = false,
}: {
children?: ReactNode
open?: boolean
} = {}) => renderWithSafeViewport(
<Combobox open={open} defaultValue="workflow" items={['workflow', 'dataset']}>
{children ?? (
<>
<ComboboxLabel data-testid="label">Resource type</ComboboxLabel>
<ComboboxTrigger aria-label="Resource type" data-testid="trigger">
<ComboboxValue placeholder="Select resource" />
</ComboboxTrigger>
<ComboboxContent
positionerProps={{
'role': 'group',
'aria-label': 'combobox positioner',
}}
popupProps={{
'role': 'dialog',
'aria-label': 'combobox popup',
}}
>
<ComboboxStatus data-testid="status">2 options</ComboboxStatus>
<ComboboxList role="listbox" aria-label="combobox list" data-testid="list">
<ComboboxItem value="workflow">
<ComboboxItemText>Workflow</ComboboxItemText>
<ComboboxItemIndicator />
</ComboboxItem>
<ComboboxItem value="dataset">
<ComboboxItemText>Dataset</ComboboxItemText>
</ComboboxItem>
</ComboboxList>
<ComboboxEmpty data-testid="empty">No options</ComboboxEmpty>
</ComboboxContent>
</>
)}
</Combobox>,
)
const renderInputCombobox = ({
children,
open = false,
}: {
children?: ReactNode
open?: boolean
} = {}) => renderWithSafeViewport(
<Combobox open={open} defaultValue="workflow" items={['workflow', 'dataset']}>
{children ?? (
<>
<ComboboxInputGroup data-testid="input-group">
<ComboboxInput aria-label="Search resources" data-testid="input" />
<ComboboxClear data-testid="clear" />
<ComboboxInputTrigger data-testid="input-trigger" />
</ComboboxInputGroup>
<ComboboxContent popupProps={{ 'role': 'dialog', 'aria-label': 'combobox popup' }}>
<ComboboxList role="listbox" aria-label="combobox list">
<ComboboxItem value="workflow">
<ComboboxItemText>Workflow</ComboboxItemText>
<ComboboxItemIndicator />
</ComboboxItem>
</ComboboxList>
</ComboboxContent>
</>
)}
</Combobox>,
)
describe('Combobox wrappers', () => {
describe('Select-like trigger', () => {
it('should render label and apply medium trigger classes by default', async () => {
const screen = await renderSelectLikeCombobox()
await expect.element(screen.getByText('Resource type')).toHaveClass('system-sm-medium')
await expect.element(screen.getByRole('combobox', { name: 'Resource type' })).toHaveClass('rounded-lg')
await expect.element(screen.getByRole('combobox', { name: 'Resource type' })).toHaveClass('system-sm-regular')
})
it('should apply small and large trigger size variants', async () => {
const smallScreen = await renderSelectLikeCombobox({
children: (
<ComboboxTrigger aria-label="Small resource type" size="small">
<ComboboxValue placeholder="Select resource" />
</ComboboxTrigger>
),
})
await expect.element(smallScreen.getByRole('combobox', { name: 'Small resource type' })).toHaveClass('rounded-md')
await expect.element(smallScreen.getByRole('combobox', { name: 'Small resource type' })).toHaveClass('system-xs-regular')
const largeScreen = await renderSelectLikeCombobox({
children: (
<ComboboxTrigger aria-label="Large resource type" size="large">
<ComboboxValue placeholder="Select resource" />
</ComboboxTrigger>
),
})
await expect.element(largeScreen.getByRole('combobox', { name: 'Large resource type' })).toHaveClass('rounded-[10px]')
await expect.element(largeScreen.getByRole('combobox', { name: 'Large resource type' })).toHaveClass('system-md-regular')
})
it('should render default trigger icon and support hiding it', async () => {
const withIcon = await renderSelectLikeCombobox()
expect(withIcon.getByTestId('trigger').element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true')
const withoutIcon = await renderSelectLikeCombobox({
children: (
<ComboboxTrigger aria-label="Resource type without icon" icon={false}>
<ComboboxValue placeholder="Select resource" />
</ComboboxTrigger>
),
})
expect(withoutIcon.getByRole('combobox', { name: 'Resource type without icon' }).element().querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument()
})
})
describe('Input group and controls', () => {
it('should apply medium input group and input classes by default', async () => {
const screen = await renderInputCombobox()
await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-lg')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('px-3')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('system-sm-regular')
})
it('should apply large input group and input classes when large size is provided', async () => {
const screen = await renderInputCombobox({
children: (
<ComboboxInputGroup size="large" data-testid="input-group">
<ComboboxInput size="large" aria-label="Search resources" />
</ComboboxInputGroup>
),
})
await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-[10px]')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('px-4')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('system-md-regular')
})
it('should set input defaults and forward passthrough props', async () => {
const screen = await renderInputCombobox({
children: (
<ComboboxInputGroup>
<ComboboxInput
aria-label="Search resources"
className="custom-input"
placeholder="Find a resource"
required
/>
</ComboboxInputGroup>
),
})
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('autocomplete', 'off')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('type', 'text')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('placeholder', 'Find a resource')
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toBeRequired()
await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('custom-input')
})
it('should provide fallback aria labels and decorative icons for input controls', async () => {
const screen = await renderInputCombobox()
await expect.element(screen.getByRole('button', { name: 'Clear combobox' })).toHaveAttribute('type', 'button')
await expect.element(screen.getByRole('button', { name: 'Open combobox options' })).toHaveAttribute('type', 'button')
expect(screen.getByRole('button', { name: 'Clear combobox' }).element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true')
expect(screen.getByRole('button', { name: 'Open combobox options' }).element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true')
})
it('should rely on aria-labelledby when provided instead of injecting fallback labels', async () => {
const screen = await renderInputCombobox({
children: (
<>
<span id="clear-label">Clear from label</span>
<span id="trigger-label">Trigger from label</span>
<ComboboxInputGroup>
<ComboboxInput aria-label="Search resources" />
<ComboboxClear aria-labelledby="clear-label" />
<ComboboxInputTrigger aria-labelledby="trigger-label" />
</ComboboxInputGroup>
</>
),
})
await expect.element(screen.getByRole('button', { name: 'Clear from label' })).not.toHaveAttribute('aria-label')
await expect.element(screen.getByRole('button', { name: 'Trigger from label' })).not.toHaveAttribute('aria-label')
})
})
describe('Content and options', () => {
it('should use default overlay placement and Dify popup classes', async () => {
const screen = await renderSelectLikeCombobox({ open: true })
await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'bottom')
await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-align', 'start')
await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-1002')
await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('rounded-xl')
await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('w-(--anchor-width)')
await expect.element(screen.getByRole('listbox', { name: 'combobox list' })).toHaveClass('scroll-py-1')
})
it('should apply custom placement side and passthrough popup props', async () => {
const onPopupClick = vi.fn()
const screen = await renderWithSafeViewport(
<Combobox open defaultValue="workflow" items={['workflow']}>
<ComboboxTrigger aria-label="Resource type">
<ComboboxValue />
</ComboboxTrigger>
<ComboboxContent
placement="top-end"
sideOffset={12}
alignOffset={6}
positionerProps={{ 'role': 'group', 'aria-label': 'combobox positioner' }}
popupProps={{
'role': 'dialog',
'aria-label': 'combobox popup',
'onClick': onPopupClick,
}}
>
<ComboboxList role="listbox" aria-label="combobox list">
<ComboboxItem value="workflow">
<ComboboxItemText>Workflow</ComboboxItemText>
</ComboboxItem>
</ComboboxList>
</ComboboxContent>
</Combobox>,
)
asHTMLElement(screen.getByRole('dialog', { name: 'combobox popup' }).element()).click()
await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'top')
expect(onPopupClick).toHaveBeenCalledTimes(1)
})
it('should render item text indicator status and empty wrappers with design classes', async () => {
const screen = await renderSelectLikeCombobox({ open: true })
await expect.element(screen.getByTestId('list').getByText('Workflow')).toHaveClass('system-sm-medium')
await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary')
await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular')
expect(screen.getByTestId('list').getByText('Workflow').element().parentElement?.querySelector('.i-ri-check-line')).toHaveAttribute('aria-hidden', 'true')
})
it('should forward custom classes to group label separator item text and indicator', async () => {
const screen = await renderWithSafeViewport(
<Combobox open defaultValue="workflow" items={['workflow']}>
<ComboboxTrigger aria-label="Resource type">
<ComboboxValue />
</ComboboxTrigger>
<ComboboxContent popupProps={{ 'role': 'dialog', 'aria-label': 'combobox popup' }}>
<ComboboxList role="listbox" aria-label="combobox list" data-testid="custom-list">
<ComboboxGroup items={['workflow']}>
<ComboboxGroupLabel className="custom-label">Resources</ComboboxGroupLabel>
<ComboboxSeparator className="custom-separator" data-testid="separator" />
<ComboboxItem value="workflow" className="custom-item">
<ComboboxItemText className="custom-text">Workflow</ComboboxItemText>
<ComboboxItemIndicator className="custom-indicator" data-testid="indicator" />
</ComboboxItem>
</ComboboxGroup>
</ComboboxList>
</ComboboxContent>
</Combobox>,
)
await expect.element(screen.getByText('Resources')).toHaveClass('custom-label')
await expect.element(screen.getByTestId('separator')).toHaveClass('custom-separator')
await expect.element(screen.getByRole('option', { name: 'Workflow' })).toHaveClass('custom-item')
await expect.element(screen.getByTestId('custom-list').getByText('Workflow')).toHaveClass('custom-text')
await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator')
})
})
describe('Multiple selection chips', () => {
it('should render chip wrappers and default remove button label', async () => {
const screen = await renderWithSafeViewport(
<Combobox multiple defaultValue={['maya']} items={['maya', 'nora']}>
<ComboboxInputGroup>
<ComboboxValue>
{(selectedValue: string[]) => (
<ComboboxChips className="custom-chips" data-testid="chips">
{selectedValue.map(item => (
<ComboboxChip key={item} className="custom-chip">
<span>{item}</span>
<ComboboxChipRemove data-testid="remove-chip" />
</ComboboxChip>
))}
</ComboboxChips>
)}
</ComboboxValue>
<ComboboxInput aria-label="Reviewers" />
</ComboboxInputGroup>
</Combobox>,
)
await expect.element(screen.getByTestId('chips')).toHaveClass('custom-chips')
await expect.element(screen.getByText('maya').element().parentElement!).toHaveClass('custom-chip')
await expect.element(screen.getByRole('button', { name: 'Remove selected item' })).toHaveAttribute('type', 'button')
expect(screen.getByTestId('remove-chip').element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true')
})
it('should preserve chip remove aria-labelledby over fallback label', async () => {
const screen = await renderWithSafeViewport(
<Combobox multiple defaultValue={['maya']} items={['maya']}>
<ComboboxInputGroup>
<ComboboxValue>
{(selectedValue: string[]) => (
<ComboboxChips>
{selectedValue.map(item => (
<ComboboxChip key={item}>
<span id="remove-maya">Remove Maya</span>
<ComboboxChipRemove aria-labelledby="remove-maya" />
</ComboboxChip>
))}
</ComboboxChips>
)}
</ComboboxValue>
<ComboboxInput aria-label="Reviewers" />
</ComboboxInputGroup>
</Combobox>,
)
await expect.element(screen.getByRole('button', { name: 'Remove Maya' })).not.toHaveAttribute('aria-label')
})
})
})

View File

@ -0,0 +1,618 @@
import type { Meta, StoryObj } from '@storybook/react-vite'
import type { Virtualizer } from '@tanstack/react-virtual'
import type { RefObject } from 'react'
import { useVirtualizer } from '@tanstack/react-virtual'
import { useEffect, useRef, useState } from 'react'
import {
Combobox,
ComboboxChip,
ComboboxChipRemove,
ComboboxChips,
ComboboxClear,
ComboboxCollection,
ComboboxContent,
ComboboxEmpty,
ComboboxGroup,
ComboboxGroupLabel,
ComboboxInput,
ComboboxInputGroup,
ComboboxInputTrigger,
ComboboxItem,
ComboboxItemIndicator,
ComboboxItemText,
ComboboxLabel,
ComboboxList,
ComboboxSeparator,
ComboboxStatus,
ComboboxTrigger,
ComboboxValue,
useComboboxFilteredItems,
} from '.'
import { cn } from '../cn'
type Option = {
value: string
label: string
meta?: string
icon?: string
disabled?: boolean
}
type OptionGroup = {
label: string
items: Option[]
}
const fieldWidth = 'w-80'
const wideFieldWidth = 'w-[520px]'
const nativeFieldLabelClassName = 'mb-1 block text-text-secondary system-sm-medium'
type StoryVirtualizer = Virtualizer<HTMLDivElement, Element>
const scrollHighlightedVirtualItem = (
item: unknown,
{
reason,
index,
}: {
reason: 'keyboard' | 'pointer' | 'none'
index: number
},
virtualizer: StoryVirtualizer | null,
) => {
if (!item || !virtualizer)
return
const isStart = index === 0
const isEnd = index === virtualizer.options.count - 1
const shouldScroll = reason === 'none' || (reason === 'keyboard' && (isStart || isEnd))
if (shouldScroll) {
queueMicrotask(() => {
virtualizer.scrollToIndex(index, { align: isEnd ? 'start' : 'end' })
})
}
}
const providerOptions: Option[] = [
{ value: 'openai', label: 'OpenAI', meta: 'GPT-5, GPT-4.1', icon: 'i-ri-openai-fill' },
{ value: 'anthropic', label: 'Anthropic', meta: 'Claude Opus, Sonnet', icon: 'i-ri-sparkling-2-line' },
{ value: 'google', label: 'Google', meta: 'Gemini 2.5', icon: 'i-ri-google-fill' },
{ value: 'azure-openai', label: 'Azure OpenAI', meta: 'Enterprise workspace', icon: 'i-ri-microsoft-fill' },
{ value: 'localai', label: 'LocalAI', meta: 'Self-hosted endpoint', icon: 'i-ri-server-line', disabled: true },
]
const dataSourceOptions: Option[] = [
{ value: 'knowledge-base', label: 'Knowledge Base', meta: 'Vector index', icon: 'i-ri-database-2-line' },
{ value: 'notion', label: 'Notion', meta: 'Synced pages', icon: 'i-ri-notion-fill' },
{ value: 'website', label: 'Website crawler', meta: 'Public URLs', icon: 'i-ri-global-line' },
{ value: 's3', label: 'S3 bucket', meta: 'Private files', icon: 'i-ri-cloud-line' },
{ value: 'slack', label: 'Slack', meta: 'Channel history', icon: 'i-ri-slack-fill' },
]
const reviewerOptions: Option[] = [
{ value: 'maya', label: 'Maya Chen', meta: 'Product owner' },
{ value: 'liam', label: 'Liam Brooks', meta: 'Prompt engineer' },
{ value: 'nora', label: 'Nora Park', meta: 'Data steward' },
{ value: 'owen', label: 'Owen Reed', meta: 'Security reviewer' },
{ value: 'yuki', label: 'Yuki Tanaka', meta: 'ML engineer' },
]
const toolGroups: OptionGroup[] = [
{
label: 'Retrieval',
items: [
{ value: 'dataset-search', label: 'Dataset search', meta: 'Search workspace knowledge', icon: 'i-ri-search-eye-line' },
{ value: 'web-scraper', label: 'Web scraper', meta: 'Fetch public pages', icon: 'i-ri-global-line' },
],
},
{
label: 'Actions',
items: [
{ value: 'http-request', label: 'HTTP request', meta: 'Call external APIs', icon: 'i-ri-terminal-box-line' },
{ value: 'code-runner', label: 'Code runner', meta: 'Execute sandboxed scripts', icon: 'i-ri-code-s-slash-line' },
],
},
{
label: 'Operations',
items: [
{ value: 'human-review', label: 'Human review', meta: 'Assign approval task', icon: 'i-ri-user-voice-line' },
{ value: 'audit-log', label: 'Audit log', meta: 'Record workflow events', icon: 'i-ri-file-list-3-line' },
],
},
]
const tagOptions: Option[] = [
{ value: 'rag', label: 'RAG' },
{ value: 'agent', label: 'Agent' },
{ value: 'production', label: 'Production' },
{ value: 'evaluation', label: 'Evaluation' },
{ value: 'finance', label: 'Finance' },
{ value: 'support', label: 'Support' },
]
const directoryOptions: Option[] = [
{ value: 'maya-chen', label: 'Maya Chen', meta: 'Product owner · maya@example.com', icon: 'i-ri-user-3-line' },
{ value: 'liam-brooks', label: 'Liam Brooks', meta: 'Prompt engineer · liam@example.com', icon: 'i-ri-user-3-line' },
{ value: 'nora-park', label: 'Nora Park', meta: 'Data steward · nora@example.com', icon: 'i-ri-user-3-line' },
{ value: 'owen-reed', label: 'Owen Reed', meta: 'Security reviewer · owen@example.com', icon: 'i-ri-shield-user-line' },
{ value: 'yuki-tanaka', label: 'Yuki Tanaka', meta: 'ML engineer · yuki@example.com', icon: 'i-ri-user-3-line' },
{ value: 'ava-martin', label: 'Ava Martin', meta: 'Support lead · ava@example.com', icon: 'i-ri-customer-service-2-line' },
]
const emptyOptions: Option[] = [
{ value: 'billing', label: 'Billing connector' },
{ value: 'zendesk', label: 'Zendesk' },
{ value: 'github', label: 'GitHub issues' },
]
const modelCatalogOptions: Option[] = Array.from({ length: 1000 }, (_, index) => {
const provider = ['OpenAI', 'Anthropic', 'Google', 'Mistral', 'DeepSeek'][index % 5]!
const family = ['chat', 'reasoning', 'vision', 'embedding'][index % 4]!
const number = new Intl.NumberFormat('en-US', {
minimumIntegerDigits: 4,
}).format(index + 1)
return {
value: `model-${index + 1}`,
label: `${provider} ${family} ${number}`,
meta: `${provider} provider · ${family}`,
icon: family === 'embedding'
? 'i-ri-vector-triangle'
: family === 'vision'
? 'i-ri-image-circle-line'
: family === 'reasoning'
? 'i-ri-brain-line'
: 'i-ri-chat-1-line',
}
})
const sizeOptions: Option[] = providerOptions.slice(0, 3)
const defaultProvider = providerOptions[0]!
const disabledProvider = providerOptions[1]!
const defaultDataSource = dataSourceOptions[0]!
const defaultPopupDataSource = dataSourceOptions[1]!
const readOnlyDataSource = dataSourceOptions[2]!
const defaultTool = toolGroups[0]!.items[0]!
const defaultReviewers = [reviewerOptions[0]!, reviewerOptions[2]!]
const defaultTag = tagOptions[2]!
const renderOptionItem = (option: Option, index?: number) => (
<ComboboxItem key={option.value} value={option} index={index} disabled={option.disabled}>
<ComboboxItemText className="flex items-center gap-2 px-0">
{option.icon && <span aria-hidden className={cn(option.icon, 'size-4 shrink-0 text-text-tertiary')} />}
<span className="min-w-0 flex-1">
<span className="block truncate text-text-secondary system-sm-medium">{option.label}</span>
{option.meta && <span className="block truncate text-text-tertiary system-xs-regular">{option.meta}</span>}
</span>
</ComboboxItemText>
<ComboboxItemIndicator />
</ComboboxItem>
)
const renderSimpleOptionItem = (option: Option, index?: number) => (
<ComboboxItem key={option.value} value={option} index={index}>
<ComboboxItemText>{option.label}</ComboboxItemText>
<ComboboxItemIndicator />
</ComboboxItem>
)
const PopupSearchInput = ({
label,
placeholder,
}: {
label: string
placeholder: string
}) => (
<ComboboxInputGroup className="mb-1 border-divider-subtle bg-components-input-bg-normal">
<span aria-hidden className="ml-2 i-ri-search-line size-4 shrink-0 text-text-tertiary" />
<ComboboxInput aria-label={label} placeholder={`${placeholder}`} className="pl-2" />
<ComboboxClear />
</ComboboxInputGroup>
)
const GroupedToolList = () => {
const groups = useComboboxFilteredItems<OptionGroup>()
return (
<ComboboxList className="p-0">
{groups.map((group, groupIndex) => (
<ComboboxGroup key={group.label} items={group.items}>
{groupIndex > 0 && <ComboboxSeparator />}
<ComboboxGroupLabel>{group.label}</ComboboxGroupLabel>
<ComboboxCollection>
{(option: Option) => renderOptionItem(option)}
</ComboboxCollection>
</ComboboxGroup>
))}
</ComboboxList>
)
}
const VirtualizedModelList = ({
virtualizerRef,
}: {
virtualizerRef: RefObject<StoryVirtualizer | null>
}) => {
const scrollRef = useRef<HTMLDivElement | null>(null)
const filteredItems = useComboboxFilteredItems<Option>()
const virtualizer = useVirtualizer({
count: filteredItems.length,
getScrollElement: () => scrollRef.current,
estimateSize: () => 42,
overscan: 6,
})
useEffect(() => {
virtualizerRef.current = virtualizer
return () => {
virtualizerRef.current = null
}
}, [virtualizer, virtualizerRef])
return (
<div
ref={scrollRef}
className="max-h-[min(22rem,var(--available-height))] overflow-y-auto overflow-x-hidden overscroll-contain outline-hidden"
>
<ComboboxList
className="relative max-h-none overflow-visible p-0"
style={{
height: virtualizer.getTotalSize(),
}}
>
{virtualizer.getVirtualItems().map((virtualItem) => {
const option = filteredItems[virtualItem.index]
if (!option)
return null
return (
<div
key={virtualItem.key}
className="absolute top-0 left-0 w-full"
style={{
height: virtualItem.size,
transform: `translateY(${virtualItem.start}px)`,
}}
>
{renderOptionItem(option, virtualItem.index)}
</div>
)
})}
</ComboboxList>
</div>
)
}
const FilteredModelStatus = () => {
const filteredItems = useComboboxFilteredItems<Option>()
return (
<ComboboxStatus className="border-y border-divider-subtle px-2 py-1 text-text-quaternary tabular-nums">
{filteredItems.length}
{' '}
matching models
</ComboboxStatus>
)
}
const VirtualizedLongListDemo = () => {
const [value, setValue] = useState<Option | null>(modelCatalogOptions[137]!)
const virtualizerRef = useRef<StoryVirtualizer | null>(null)
return (
<div className={fieldWidth}>
<Combobox
items={modelCatalogOptions}
value={value}
onValueChange={setValue}
virtualized
autoHighlight
onItemHighlighted={(item, details) => {
scrollHighlightedVirtualItem(item, details, virtualizerRef.current)
}}
>
<ComboboxLabel>Model catalog</ComboboxLabel>
<ComboboxTrigger aria-label="Model catalog">
<ComboboxValue placeholder="Select model" />
</ComboboxTrigger>
<ComboboxContent popupClassName="w-[440px] p-1">
<PopupSearchInput label="Filter model catalog" placeholder="Filter 1,000 models" />
<FilteredModelStatus />
<VirtualizedModelList virtualizerRef={virtualizerRef} />
<ComboboxEmpty>No model matches this filter</ComboboxEmpty>
</ComboboxContent>
</Combobox>
</div>
)
}
const AsyncDirectoryDemo = () => {
const [inputValue, setInputValue] = useState('ma')
const [value, setValue] = useState<Option | null>(null)
const [items, setItems] = useState(directoryOptions.slice(0, 3))
const [loading, setLoading] = useState(false)
useEffect(() => {
setLoading(true)
const timeout = window.setTimeout(() => {
const query = inputValue.trim().toLowerCase()
setItems(
query
? directoryOptions.filter(option => `${option.label} ${option.meta}`.toLowerCase().includes(query))
: directoryOptions.slice(0, 5),
)
setLoading(false)
}, 450)
return () => window.clearTimeout(timeout)
}, [inputValue])
return (
<div className={fieldWidth}>
<Combobox
items={value && !items.some(item => item.value === value.value) ? [value, ...items] : items}
value={value}
onValueChange={setValue}
inputValue={inputValue}
onInputValueChange={setInputValue}
autoHighlight
>
<label className={nativeFieldLabelClassName}>
Owner
<ComboboxInputGroup className="mt-1">
<span aria-hidden className="ml-3 i-ri-search-line size-4 shrink-0 text-text-tertiary" />
<ComboboxInput placeholder="Search owners…" className="pl-2" />
<ComboboxClear />
<ComboboxInputTrigger />
</ComboboxInputGroup>
</label>
<ComboboxContent popupClassName="w-[420px]">
<ComboboxStatus className="border-b border-divider-subtle">
{loading ? 'Loading directory matches…' : `${items.length} selectable owners`}
</ComboboxStatus>
<ComboboxList>{renderOptionItem}</ComboboxList>
<ComboboxEmpty>No owner matches this query</ComboboxEmpty>
</ComboboxContent>
</Combobox>
</div>
)
}
const meta = {
title: 'Base/UI/Combobox',
component: Combobox,
parameters: {
layout: 'centered',
docs: {
description: {
component: 'Compound combobox built on Base UI Combobox for searchable predefined selections. Compose triggers, inputs, lists, groups, status, empty states, and chips without importing Base UI primitives directly.',
},
},
},
tags: ['autodocs'],
} satisfies Meta<typeof Combobox>
export default meta
type Story = StoryObj<typeof meta>
export const SelectLikeDefault: Story = {
render: () => (
<div className={fieldWidth}>
<Combobox items={providerOptions} defaultValue={defaultProvider} autoHighlight>
<ComboboxLabel>Model provider</ComboboxLabel>
<ComboboxTrigger aria-label="Model provider">
<ComboboxValue placeholder="Select provider" />
</ComboboxTrigger>
<ComboboxContent popupClassName="p-1">
<PopupSearchInput label="Search model providers" placeholder="Search providers" />
<ComboboxList className="p-0">{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
</div>
),
}
export const PopupInputSearchableSelect: Story = {
render: () => (
<div className={fieldWidth}>
<Combobox items={dataSourceOptions} defaultValue={defaultPopupDataSource} autoHighlight>
<ComboboxLabel>Data source</ComboboxLabel>
<ComboboxTrigger aria-label="Data source">
<ComboboxValue placeholder="Choose source" />
</ComboboxTrigger>
<ComboboxContent popupClassName="p-1">
<PopupSearchInput label="Search data sources" placeholder="Search sources" />
<ComboboxList className="p-0">{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
</div>
),
}
export const AsyncSearchSingle: Story = {
render: () => <AsyncDirectoryDemo />,
}
export const InputGroupSearchable: Story = {
render: () => (
<div className={fieldWidth}>
<Combobox items={dataSourceOptions} defaultValue={defaultDataSource} autoHighlight>
<label className={nativeFieldLabelClassName}>
Connect source
<ComboboxInputGroup className="mt-1">
<span aria-hidden className="ml-3 i-ri-search-line size-4 shrink-0 text-text-tertiary" />
<ComboboxInput placeholder="Search data sources…" className="pl-2" />
<ComboboxClear />
<ComboboxInputTrigger />
</ComboboxInputGroup>
</label>
<ComboboxContent>
<ComboboxList>{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
</div>
),
}
export const Sizes: Story = {
render: () => (
<div className="flex w-80 flex-col gap-3">
{(['small', 'medium', 'large'] as const).map(size => (
<Combobox key={size} items={sizeOptions} defaultValue={defaultProvider} autoHighlight>
<ComboboxTrigger aria-label={`${size} model provider`} size={size}>
<ComboboxValue />
</ComboboxTrigger>
<ComboboxContent popupClassName="p-1">
<PopupSearchInput label={`Search ${size} model providers`} placeholder="Search providers" />
<ComboboxList className="p-0">{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
))}
</div>
),
}
export const Grouped: Story = {
render: () => (
<div className={fieldWidth}>
<Combobox items={toolGroups} defaultValue={defaultTool} autoHighlight>
<ComboboxLabel>Workflow tool</ComboboxLabel>
<ComboboxTrigger aria-label="Workflow tool">
<ComboboxValue placeholder="Select tool" />
</ComboboxTrigger>
<ComboboxContent popupClassName="p-1">
<PopupSearchInput label="Search workflow tools" placeholder="Search workflow tools" />
<GroupedToolList />
</ComboboxContent>
</Combobox>
</div>
),
}
const MultipleChipsDemo = () => {
const [value, setValue] = useState<Option[]>(defaultReviewers)
return (
<div className={wideFieldWidth}>
<Combobox items={reviewerOptions} multiple value={value} onValueChange={setValue} autoHighlight>
<label className={nativeFieldLabelClassName}>
Reviewers
<ComboboxInputGroup className="mt-1 h-auto min-h-8 flex-nowrap py-1">
<ComboboxValue>
{(selectedValue: Option[]) => (
<>
<ComboboxChips className="flex-nowrap">
{selectedValue.map(item => (
<ComboboxChip key={item.value}>
<span className="max-w-32 truncate">{item.label}</span>
<ComboboxChipRemove aria-label={`Remove ${item.label}`} />
</ComboboxChip>
))}
</ComboboxChips>
<ComboboxInput placeholder={selectedValue.length ? '' : 'Assign reviewers…'} className="min-w-16 px-2" />
</>
)}
</ComboboxValue>
<ComboboxClear />
<ComboboxInputTrigger />
</ComboboxInputGroup>
</label>
<ComboboxContent>
<ComboboxList>{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
</div>
)
}
export const MultipleChips: Story = {
render: () => <MultipleChipsDemo />,
}
export const VirtualizedLongList: Story = {
render: () => <VirtualizedLongListDemo />,
}
export const EmptyAndStatus: Story = {
render: () => (
<div className={fieldWidth}>
<Combobox items={emptyOptions} defaultInputValue="salesforce" autoHighlight>
<label className={nativeFieldLabelClassName}>
Connector
<ComboboxInputGroup className="mt-1">
<span aria-hidden className="ml-3 i-ri-search-line size-4 shrink-0 text-text-tertiary" />
<ComboboxInput placeholder="Search connectors…" className="pl-2" />
<ComboboxClear />
<ComboboxInputTrigger />
</ComboboxInputGroup>
</label>
<ComboboxContent>
<ComboboxStatus>Search workspace connectors</ComboboxStatus>
<ComboboxEmpty>No connectors found</ComboboxEmpty>
<ComboboxList>{renderSimpleOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
</div>
),
}
export const DisabledAndReadOnly: Story = {
render: () => (
<div className="flex w-80 flex-col gap-3">
<Combobox items={providerOptions} defaultValue={disabledProvider} disabled>
<ComboboxLabel>Disabled provider</ComboboxLabel>
<ComboboxTrigger aria-label="Disabled model provider">
<ComboboxValue />
</ComboboxTrigger>
<ComboboxContent popupClassName="p-1">
<PopupSearchInput label="Search disabled providers" placeholder="Search providers" />
<ComboboxList className="p-0">{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
<Combobox items={dataSourceOptions} defaultValue={readOnlyDataSource} readOnly>
<label className={nativeFieldLabelClassName}>
Read-only source
<ComboboxInputGroup className="mt-1">
<ComboboxInput placeholder="Read-only data source…" />
<ComboboxClear />
<ComboboxInputTrigger />
</ComboboxInputGroup>
</label>
<ComboboxContent>
<ComboboxList>{renderOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
</div>
),
}
const ControlledDemo = () => {
const [value, setValue] = useState<Option | null>(defaultTag)
return (
<div className="flex w-80 flex-col items-start gap-3">
<Combobox items={tagOptions} value={value} onValueChange={setValue}>
<ComboboxLabel>Default app tag</ComboboxLabel>
<ComboboxTrigger aria-label="Default app tag">
<ComboboxValue placeholder="Select tag" />
</ComboboxTrigger>
<ComboboxContent popupClassName="p-1">
<PopupSearchInput label="Search app tags" placeholder="Search tags" />
<ComboboxList className="p-0">{renderSimpleOptionItem}</ComboboxList>
</ComboboxContent>
</Combobox>
<span className="rounded-md border border-divider-subtle bg-components-panel-bg px-2 py-1 text-text-tertiary system-xs-regular">
Selected:
{' '}
{value?.label ?? 'None'}
</span>
</div>
)
}
export const Controlled: Story = {
render: () => <ControlledDemo />,
}

View File

@ -0,0 +1,497 @@
'use client'
import type { VariantProps } from 'class-variance-authority'
import type { HTMLAttributes, ReactNode } from 'react'
import type { Placement } from '../placement'
import { Combobox as BaseCombobox } from '@base-ui/react/combobox'
import { cva } from 'class-variance-authority'
import { cn } from '../cn'
import {
overlayIndicatorClassName,
overlayLabelClassName,
overlayPopupAnimationClassName,
overlaySeparatorClassName,
} from '../overlay-shared'
import { parsePlacement } from '../placement'
export type { Placement }
export const Combobox = BaseCombobox.Root
export const ComboboxValue = BaseCombobox.Value
export const ComboboxGroup = BaseCombobox.Group
export const ComboboxCollection = BaseCombobox.Collection
export const ComboboxRow = BaseCombobox.Row
export const useComboboxFilter = BaseCombobox.useFilter
export const useComboboxFilteredItems = BaseCombobox.useFilteredItems
export type ComboboxRootProps<Value, Multiple extends boolean | undefined = false>
= BaseCombobox.Root.Props<Value, Multiple>
export type ComboboxRootChangeEventDetails = BaseCombobox.Root.ChangeEventDetails
export type ComboboxRootHighlightEventDetails = BaseCombobox.Root.HighlightEventDetails
const comboboxPopupClassName = [
'w-(--anchor-width) max-w-[min(28rem,var(--available-width))] overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg outline-hidden',
'data-side-top:origin-bottom data-side-bottom:origin-top data-side-left:origin-right data-side-right:origin-left',
]
const comboboxListClassName = [
'max-h-[min(20rem,var(--available-height))] overflow-y-auto overflow-x-hidden overscroll-contain p-1 outline-hidden scroll-py-1',
'data-empty:max-h-none data-empty:p-0',
]
const comboboxItemClassName = [
'mx-1 grid min-h-8 cursor-pointer select-none grid-cols-[1fr_auto] items-center gap-2 rounded-lg px-2 py-1.5 text-text-secondary outline-hidden transition-colors',
'hover:bg-state-base-hover-alt hover:text-text-primary',
'data-highlighted:bg-state-base-hover data-highlighted:text-text-primary',
'data-selected:text-text-primary',
'data-disabled:cursor-not-allowed data-disabled:opacity-30 data-disabled:hover:bg-transparent data-disabled:hover:text-text-secondary',
'motion-reduce:transition-none',
]
const comboboxTriggerVariants = cva(
[
'group/combobox-trigger flex w-full min-w-0 items-center border-0 bg-components-input-bg-normal text-left text-components-input-text-filled outline-hidden transition-colors',
'hover:bg-state-base-hover-alt focus-visible:bg-state-base-hover-alt data-open:bg-state-base-hover-alt',
'focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:ring-inset',
'data-placeholder:text-components-input-text-placeholder',
'data-readonly:cursor-default data-readonly:bg-transparent data-readonly:hover:bg-transparent',
'data-disabled:cursor-not-allowed data-disabled:bg-components-input-bg-disabled data-disabled:text-components-input-text-filled-disabled data-disabled:hover:bg-components-input-bg-disabled',
'data-disabled:data-placeholder:text-components-input-text-disabled',
'motion-reduce:transition-none',
],
{
variants: {
size: {
small: 'h-6 gap-px rounded-md px-2 py-1 system-xs-regular',
medium: 'h-8 gap-0.5 rounded-lg px-3 py-2 system-sm-regular',
large: 'h-9 gap-0.5 rounded-[10px] px-4 py-2 system-md-regular',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type ComboboxSize = NonNullable<VariantProps<typeof comboboxTriggerVariants>['size']>
type ComboboxTriggerProps
= Omit<BaseCombobox.Trigger.Props, 'className'>
& VariantProps<typeof comboboxTriggerVariants>
& {
className?: string
icon?: ReactNode | false
}
export function ComboboxTrigger({
className,
children,
icon,
size,
type = 'button',
...props
}: ComboboxTriggerProps) {
return (
<BaseCombobox.Trigger
type={type}
className={cn(comboboxTriggerVariants({ size, className }))}
{...props}
>
<span className="min-w-0 grow truncate">
{children}
</span>
{icon !== false && (
<BaseCombobox.Icon className="shrink-0 text-text-quaternary transition-colors group-hover/combobox-trigger:text-text-secondary group-data-open/combobox-trigger:text-text-secondary group-data-readonly/combobox-trigger:hidden">
{icon ?? <span className="i-ri-arrow-down-s-line h-4 w-4" aria-hidden="true" />}
</BaseCombobox.Icon>
)}
</BaseCombobox.Trigger>
)
}
const comboboxInputGroupVariants = cva(
[
'group/combobox flex w-full min-w-0 items-center border border-transparent bg-components-input-bg-normal text-components-input-text-filled shadow-none outline-hidden transition-[background-color,border-color,box-shadow]',
'hover:border-components-input-border-hover hover:bg-components-input-bg-hover',
'focus-within:border-components-input-border-active focus-within:bg-components-input-bg-active focus-within:shadow-xs',
'data-focused:border-components-input-border-active data-focused:bg-components-input-bg-active data-focused:shadow-xs',
'data-open:border-components-input-border-active data-open:bg-components-input-bg-active',
'data-disabled:cursor-not-allowed data-disabled:border-transparent data-disabled:bg-components-input-bg-disabled data-disabled:text-components-input-text-filled-disabled',
'data-disabled:hover:border-transparent data-disabled:hover:bg-components-input-bg-disabled',
'data-readonly:shadow-none data-readonly:hover:border-transparent data-readonly:hover:bg-components-input-bg-normal',
'motion-reduce:transition-none',
],
{
variants: {
size: {
small: 'min-h-6 rounded-md',
medium: 'min-h-8 rounded-lg',
large: 'min-h-9 rounded-[10px]',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type ComboboxInputGroupProps
= BaseCombobox.InputGroup.Props
& VariantProps<typeof comboboxInputGroupVariants>
export function ComboboxInputGroup({
className,
size = 'medium',
...props
}: ComboboxInputGroupProps) {
return (
<BaseCombobox.InputGroup
className={cn(comboboxInputGroupVariants({ size }), className)}
{...props}
/>
)
}
const comboboxInputVariants = cva(
[
'w-0 min-w-0 flex-1 appearance-none border-0 bg-transparent text-components-input-text-filled caret-primary-600 outline-hidden',
'placeholder:text-components-input-text-placeholder',
'disabled:cursor-not-allowed disabled:text-components-input-text-filled-disabled disabled:placeholder:text-components-input-text-disabled',
'data-readonly:cursor-default',
],
{
variants: {
size: {
small: 'px-2 py-1 system-xs-regular',
medium: 'px-3 py-[7px] system-sm-regular',
large: 'px-4 py-2 system-md-regular',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type ComboboxInputProps
= Omit<BaseCombobox.Input.Props, 'size'>
& VariantProps<typeof comboboxInputVariants>
export function ComboboxInput({
className,
size = 'medium',
type = 'text',
autoComplete = 'off',
...props
}: ComboboxInputProps) {
return (
<BaseCombobox.Input
type={type}
autoComplete={autoComplete}
className={cn(comboboxInputVariants({ size }), className)}
{...props}
/>
)
}
const comboboxControlVariants = cva(
[
'flex shrink-0 touch-manipulation items-center justify-center rounded-md text-text-tertiary outline-hidden transition-colors',
'hover:bg-components-input-bg-hover hover:text-text-secondary focus-visible:bg-components-input-bg-hover focus-visible:text-text-secondary',
'focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:ring-inset',
'disabled:cursor-not-allowed disabled:hover:bg-transparent disabled:hover:text-text-tertiary disabled:focus-visible:bg-transparent disabled:focus-visible:ring-0',
'group-data-disabled/combobox:cursor-not-allowed group-data-disabled/combobox:hover:bg-transparent group-data-disabled/combobox:focus-visible:bg-transparent group-data-disabled/combobox:focus-visible:ring-0',
'group-data-readonly/combobox:hidden',
'motion-reduce:transition-none',
],
{
variants: {
size: {
small: 'mr-1 size-4',
medium: 'mr-1.5 size-5',
large: 'mr-2 size-5',
},
},
defaultVariants: {
size: 'medium',
},
},
)
export type ComboboxClearProps
= Omit<BaseCombobox.Clear.Props, 'className'>
& VariantProps<typeof comboboxControlVariants>
& { className?: string }
export function ComboboxClear({
className,
children,
size = 'medium',
type = 'button',
...props
}: ComboboxClearProps) {
return (
<BaseCombobox.Clear
type={type}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : 'Clear combobox')}
className={cn(
comboboxControlVariants({ size }),
'data-ending-style:opacity-0 data-starting-style:opacity-0',
className,
)}
{...props}
>
{children ?? <span className="i-ri-close-line size-4" aria-hidden="true" />}
</BaseCombobox.Clear>
)
}
export type ComboboxInputTriggerProps
= Omit<BaseCombobox.Trigger.Props, 'className'>
& VariantProps<typeof comboboxControlVariants>
& { className?: string }
export function ComboboxInputTrigger({
className,
children,
size = 'medium',
type = 'button',
...props
}: ComboboxInputTriggerProps) {
return (
<BaseCombobox.Trigger
type={type}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : 'Open combobox options')}
className={cn(comboboxControlVariants({ size }), className)}
{...props}
>
{children ?? <span className="i-ri-arrow-down-s-line size-4" aria-hidden="true" />}
</BaseCombobox.Trigger>
)
}
export function ComboboxIcon({
className,
children,
...props
}: BaseCombobox.Icon.Props) {
return (
<BaseCombobox.Icon
className={cn('flex shrink-0 items-center text-text-tertiary', className)}
{...props}
>
{children ?? <span className="i-ri-arrow-down-s-line size-4" aria-hidden="true" />}
</BaseCombobox.Icon>
)
}
type ComboboxContentProps = {
children: ReactNode
placement?: Placement
sideOffset?: number
alignOffset?: number
className?: string
popupClassName?: string
portalProps?: Omit<BaseCombobox.Portal.Props, 'children'>
positionerProps?: Omit<
BaseCombobox.Positioner.Props,
'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset'
>
popupProps?: Omit<
BaseCombobox.Popup.Props,
'children' | 'className'
>
}
export function ComboboxContent({
children,
placement = 'bottom-start',
sideOffset = 4,
alignOffset = 0,
className,
popupClassName,
portalProps,
positionerProps,
popupProps,
}: ComboboxContentProps) {
const { side, align } = parsePlacement(placement)
return (
<BaseCombobox.Portal {...portalProps}>
<BaseCombobox.Positioner
side={side}
align={align}
sideOffset={sideOffset}
alignOffset={alignOffset}
className={cn('z-1002 outline-hidden', className)}
{...positionerProps}
>
<BaseCombobox.Popup
className={cn(
comboboxPopupClassName,
overlayPopupAnimationClassName,
popupClassName,
)}
{...popupProps}
>
{children}
</BaseCombobox.Popup>
</BaseCombobox.Positioner>
</BaseCombobox.Portal>
)
}
export function ComboboxList({
className,
...props
}: BaseCombobox.List.Props) {
return (
<BaseCombobox.List
className={cn(comboboxListClassName, className)}
{...props}
/>
)
}
export function ComboboxItem({
className,
...props
}: BaseCombobox.Item.Props) {
return (
<BaseCombobox.Item
className={cn(comboboxItemClassName, className)}
{...props}
/>
)
}
export type ComboboxItemTextProps = HTMLAttributes<HTMLSpanElement>
export function ComboboxItemText({
className,
...props
}: ComboboxItemTextProps) {
return (
<span
className={cn('min-w-0 grow truncate px-1 system-sm-medium', className)}
{...props}
/>
)
}
export function ComboboxItemIndicator({
className,
children,
...props
}: Omit<BaseCombobox.ItemIndicator.Props, 'children'> & { children?: ReactNode }) {
return (
<BaseCombobox.ItemIndicator
className={cn(overlayIndicatorClassName, className)}
{...props}
>
{children ?? <span className="i-ri-check-line h-4 w-4" aria-hidden="true" />}
</BaseCombobox.ItemIndicator>
)
}
export function ComboboxLabel({
className,
...props
}: BaseCombobox.Label.Props) {
return (
<BaseCombobox.Label
className={cn('mb-1 block text-text-secondary system-sm-medium', className)}
{...props}
/>
)
}
export function ComboboxGroupLabel({
className,
...props
}: BaseCombobox.GroupLabel.Props) {
return (
<BaseCombobox.GroupLabel
className={cn(overlayLabelClassName, className)}
{...props}
/>
)
}
export function ComboboxSeparator({
className,
...props
}: BaseCombobox.Separator.Props) {
return (
<BaseCombobox.Separator
className={cn(overlaySeparatorClassName, className)}
{...props}
/>
)
}
export function ComboboxEmpty({
className,
...props
}: BaseCombobox.Empty.Props) {
return (
<BaseCombobox.Empty
className={cn('px-3 py-2 system-sm-regular text-text-tertiary', className)}
{...props}
/>
)
}
export function ComboboxStatus({
className,
...props
}: BaseCombobox.Status.Props) {
return (
<BaseCombobox.Status
className={cn('px-3 py-2 system-sm-regular text-text-tertiary', className)}
{...props}
/>
)
}
export function ComboboxChips({
className,
...props
}: BaseCombobox.Chips.Props) {
return (
<BaseCombobox.Chips
className={cn('flex w-full min-w-0 flex-wrap items-center gap-1 px-1', className)}
{...props}
/>
)
}
export function ComboboxChip({
className,
...props
}: BaseCombobox.Chip.Props) {
return (
<BaseCombobox.Chip
className={cn('inline-flex max-w-full min-w-0 items-center gap-1 rounded-md bg-state-base-hover px-1.5 py-0.5 text-text-secondary system-xs-medium', className)}
{...props}
/>
)
}
export function ComboboxChipRemove({
className,
children,
type = 'button',
...props
}: BaseCombobox.ChipRemove.Props) {
return (
<BaseCombobox.ChipRemove
type={type}
aria-label={props['aria-label'] ?? (props['aria-labelledby'] ? undefined : 'Remove selected item')}
className={cn('flex size-3.5 shrink-0 items-center justify-center rounded-sm text-text-tertiary outline-hidden hover:bg-state-base-hover-alt hover:text-text-secondary focus-visible:ring-1 focus-visible:ring-components-input-border-active', className)}
{...props}
>
{children ?? <span className="i-ri-close-line size-3" aria-hidden="true" />}
</BaseCombobox.ChipRemove>
)
}

2458
pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@ saveExact: true
catalogMode: prefer
dedupeDirectDeps: true
engineStrict: true
minimumReleaseAge: 1440
minimumReleaseAge: 0
optimisticRepeatInstall: true
verifyDepsBeforeRun: install
resolutionMode: time-based
@ -54,8 +54,8 @@ overrides:
yaml@>=2.0.0 <2.8.3: 2.8.3
yauzl@<3.2.1: 3.2.1
catalog:
'@amplitude/analytics-browser': 2.42.0
'@amplitude/plugin-session-replay-browser': 1.28.1
'@amplitude/analytics-browser': 2.42.1
'@amplitude/plugin-session-replay-browser': 1.29.0
'@antfu/eslint-config': 8.2.0
'@base-ui/react': 1.4.1
'@chromatic-com/storybook': 5.1.2
@ -65,11 +65,11 @@ catalog:
'@eslint-react/eslint-plugin': 3.0.0
'@eslint/js': 10.0.1
'@floating-ui/react': 0.27.19
'@formatjs/intl-localematcher': 0.8.4
'@formatjs/intl-localematcher': 0.8.6
'@headlessui/react': 2.2.10
'@heroicons/react': 2.2.0
'@hey-api/openapi-ts': 0.97.0
'@hono/node-server': 2.0.0
'@hey-api/openapi-ts': 0.97.1
'@hono/node-server': 2.0.1
'@iconify-json/heroicons': 1.2.3
'@iconify-json/ri': 1.2.10
'@lexical/code': 0.44.0
@ -85,42 +85,42 @@ catalog:
'@monaco-editor/react': 4.7.0
'@next/eslint-plugin-next': 16.2.4
'@next/mdx': 16.2.4
'@orpc/client': 1.14.0
'@orpc/contract': 1.14.0
'@orpc/openapi-client': 1.14.0
'@orpc/tanstack-query': 1.14.0
'@orpc/client': 1.14.1
'@orpc/contract': 1.14.1
'@orpc/openapi-client': 1.14.1
'@orpc/tanstack-query': 1.14.1
'@playwright/test': 1.59.1
'@remixicon/react': 4.9.0
'@rgrove/parse-xml': 4.2.0
'@sentry/react': 10.50.0
'@storybook/addon-docs': 10.3.5
'@storybook/addon-links': 10.3.5
'@storybook/addon-onboarding': 10.3.5
'@storybook/addon-themes': 10.3.5
'@storybook/nextjs-vite': 10.3.5
'@storybook/react': 10.3.5
'@storybook/react-vite': 10.3.5
'@sentry/react': 10.51.0
'@storybook/addon-docs': 10.3.6
'@storybook/addon-links': 10.3.6
'@storybook/addon-onboarding': 10.3.6
'@storybook/addon-themes': 10.3.6
'@storybook/nextjs-vite': 10.3.6
'@storybook/react': 10.3.6
'@storybook/react-vite': 10.3.6
'@streamdown/math': 1.0.2
'@svgdotjs/svg.js': 3.2.5
'@t3-oss/env-nextjs': 0.13.11
'@tailwindcss/postcss': 4.2.4
'@tailwindcss/typography': 0.5.19
'@tailwindcss/vite': 4.2.4
'@tanstack/eslint-plugin-query': 5.100.6
'@tanstack/eslint-plugin-query': 5.100.9
'@tanstack/react-devtools': 0.10.2
'@tanstack/react-form': 1.29.1
'@tanstack/react-form-devtools': 0.2.22
'@tanstack/react-hotkeys': 0.10.0
'@tanstack/react-query': 5.100.6
'@tanstack/react-query-devtools': 5.100.6
'@tanstack/react-query': 5.100.9
'@tanstack/react-query-devtools': 5.100.9
'@tanstack/react-virtual': 3.13.24
'@testing-library/dom': 10.4.1
'@testing-library/jest-dom': 6.9.1
'@testing-library/react': 16.3.2
'@testing-library/user-event': 14.6.1
'@tsslint/cli': 3.1.0
'@tsslint/compat-eslint': 3.1.0
'@tsslint/config': 3.1.0
'@tsslint/cli': 3.1.1
'@tsslint/compat-eslint': 3.1.1
'@tsslint/config': 3.1.1
'@types/js-cookie': 3.0.6
'@types/js-yaml': 4.0.9
'@types/negotiator': 0.6.4
@ -129,9 +129,9 @@ catalog:
'@types/react': 19.2.14
'@types/react-dom': 19.2.3
'@types/sortablejs': 1.15.9
'@typescript-eslint/eslint-plugin': 8.59.1
'@typescript-eslint/parser': 8.59.1
'@typescript/native-preview': 7.0.0-dev.20260428.1
'@typescript-eslint/eslint-plugin': 8.59.2
'@typescript-eslint/parser': 8.59.2
'@typescript/native-preview': 7.0.0-dev.20260505.1
'@vitejs/plugin-react': 6.0.1
'@vitejs/plugin-rsc': 0.5.25
'@vitest/coverage-v8': 4.1.5
@ -149,44 +149,45 @@ catalog:
cron-parser: 5.5.0
dayjs: 1.11.20
decimal.js: 10.6.0
dompurify: 3.4.1
dompurify: 3.4.2
echarts: 6.0.0
echarts-for-react: 3.0.6
elkjs: 0.11.1
embla-carousel-autoplay: 8.6.0
embla-carousel-react: 8.6.0
emoji-mart: 5.6.0
es-toolkit: 1.46.0
eslint: 10.2.1
eslint-markdown: 0.7.0
es-toolkit: 1.46.1
eslint: 10.3.0
eslint-markdown: 0.8.0
eslint-plugin-better-tailwindcss: 4.5.0
eslint-plugin-hyoban: 0.14.1
eslint-plugin-markdown-preferences: 0.41.1
eslint-plugin-no-barrel-files: 1.3.1
eslint-plugin-react-refresh: 0.5.2
eslint-plugin-sonarjs: 4.0.3
eslint-plugin-storybook: 10.3.5
eslint-plugin-storybook: 10.3.6
fast-deep-equal: 3.1.3
fuse.js: 7.2.0
happy-dom: 20.9.0
hast-util-to-jsx-runtime: 2.3.6
hono: 4.12.15
hono: 4.12.17
html-entities: 2.6.0
html-to-image: 1.11.13
i18next: 26.0.8
i18next-resources-to-backend: 1.2.1
iconify-import-svg: 0.2.0
immer: 11.1.4
jotai: 2.19.1
immer: 11.1.6
jotai: 2.20.0
js-audio-recorder: 1.0.7
js-cookie: 3.0.5
js-yaml: 4.1.1
jsonschema: 1.5.0
katex: 0.16.45
knip: 6.7.0
knip: 6.11.0
ky: 2.0.2
lamejs: 1.2.1
lexical: 0.44.0
loro-crdt: 1.12.0
loro-crdt: 1.12.1
mermaid: 11.14.0
mime: 4.1.0
mitt: 3.0.1
@ -196,14 +197,14 @@ catalog:
nuqs: 2.8.9
pinyin-pro: 3.28.1
playwright: 1.59.1
postcss: 8.5.12
postcss: 8.5.14
qrcode.react: 4.2.0
qs: 6.15.1
react: 19.2.5
react-18-input-autosize: 3.0.0
react-dom: 19.2.5
react-easy-crop: 5.5.7
react-hotkeys-hook: 5.2.4
react-hotkeys-hook: 5.3.2
react-i18next: 16.5.8
react-multi-email: 1.0.25
react-papaparse: 4.4.0
@ -220,25 +221,25 @@ catalog:
socket.io-client: 4.8.3
sortablejs: 1.15.7
std-semver: 1.0.8
storybook: 10.3.5
storybook: 10.3.6
streamdown: 2.5.0
string-ts: 2.3.1
tailwind-merge: 3.5.0
tailwindcss: 4.2.4
tldts: 7.0.29
tldts: 7.0.30
tsx: 4.21.0
typescript: 6.0.3
uglify-js: 3.19.3
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 14.0.0
vinext: 0.0.45
vinext: 0.0.47
vite: npm:@voidzero-dev/vite-plus-core@0.1.20
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.20
vitest: npm:@voidzero-dev/vite-plus-test@0.1.20
vitest-browser-react: 2.2.0
vitest-canvas-mock: 1.1.4
zod: 4.3.6
zod: 4.4.3
zundo: 2.3.0
zustand: 5.0.12
zustand: 5.0.13

View File

@ -5,9 +5,9 @@
## Overlay Components (Mandatory)
- `../packages/dify-ui/README.md` is the permanent contract for overlay primitives, portals, root `isolation: isolate`, and the `z-1002` / `z-1003` layering.
- `./docs/overlay-migration.md` is the source of truth for the ongoing migration (deprecated import paths, allowlist, coexistence rules).
- `./docs/overlay-migration.md` is the source of truth for the ongoing migration (deprecated import paths and coexistence rules).
- In new or modified code, use only overlay primitives from `@langgenius/dify-ui/*`.
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding).
- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them.
## Query & Mutation (Mandatory)

View File

@ -0,0 +1,172 @@
import type { ReactNode } from 'react'
import * as React from 'react'
const DropdownMenuContext = React.createContext({
open: false,
onOpenChange: (_open: boolean) => {},
})
type DropdownMenuProps = {
children?: ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
type DropdownMenuTriggerProps = React.HTMLAttributes<HTMLElement> & {
children?: ReactNode
nativeButton?: boolean
render?: React.ReactElement
}
type DropdownMenuContentProps = React.HTMLAttributes<HTMLDivElement> & {
children?: ReactNode
placement?: string
sideOffset?: number
alignOffset?: number
popupClassName?: string
}
export const DropdownMenu = ({
children,
open,
onOpenChange,
}: DropdownMenuProps) => {
const [localOpen, setLocalOpen] = React.useState(false)
const resolvedOpen = open ?? localOpen
const handleOpenChange = React.useCallback((nextOpen: boolean) => {
setLocalOpen(nextOpen)
onOpenChange?.(nextOpen)
}, [onOpenChange])
return (
<DropdownMenuContext.Provider value={{ open: resolvedOpen, onOpenChange: handleOpenChange }}>
<div data-testid="dropdown-menu" data-open={String(resolvedOpen)}>
{children}
</div>
</DropdownMenuContext.Provider>
)
}
export const DropdownMenuTrigger = ({
children,
render,
nativeButton: _nativeButton,
onClick,
...props
}: DropdownMenuTriggerProps) => {
const { open, onOpenChange } = React.useContext(DropdownMenuContext)
const node = render ?? children
const isNativeButton = React.isValidElement(node) && node.type === 'button'
const handleClick = (event: React.MouseEvent<HTMLElement>) => {
onClick?.(event)
if (!event.defaultPrevented)
onOpenChange(!open)
}
if (React.isValidElement(node)) {
const triggerElement = node as React.ReactElement<Record<string, unknown>>
const childProps = (triggerElement.props ?? {}) as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
const triggerProps = props as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
const role = childProps.role ?? triggerProps.role ?? (!isNativeButton && (childProps['aria-label'] || triggerProps['aria-label']) ? 'button' : undefined)
return React.cloneElement(triggerElement, {
...props,
...childProps,
'data-testid': childProps['data-testid'] ?? triggerProps['data-testid'] ?? 'dropdown-menu-trigger',
role,
'tabIndex': childProps.tabIndex ?? triggerProps.tabIndex ?? (role === 'button' ? 0 : undefined),
'onClick': (event: React.MouseEvent<HTMLElement>) => {
childProps.onClick?.(event)
handleClick(event)
},
}, render ? (children ?? childProps.children) : childProps.children)
}
return (
<div data-testid="dropdown-menu-trigger" role="button" tabIndex={0} onClick={handleClick} {...props}>
{node}
</div>
)
}
export const DropdownMenuContent = ({
children,
className,
popupClassName,
placement,
sideOffset,
alignOffset,
...props
}: DropdownMenuContentProps) => {
const { open } = React.useContext(DropdownMenuContext)
if (!open)
return null
return (
<div
data-testid="dropdown-menu-content"
data-placement={placement}
data-side-offset={sideOffset}
data-align-offset={alignOffset}
className={className || popupClassName}
{...props}
>
{children}
</div>
)
}
export const DropdownMenuItem = ({
children,
onClick,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode }) => (
<div role="menuitem" onClick={onClick} {...props}>
{children}
</div>
)
export const DropdownMenuRadioGroup = ({
children,
onValueChange,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode, value?: unknown, onValueChange?: (value: unknown) => void }) => (
<div
role="radiogroup"
{...props}
data-on-value-change={onValueChange ? 'true' : undefined}
>
{React.Children.map(children, (child) => {
if (!React.isValidElement(child))
return child
return React.cloneElement(child as React.ReactElement<{ __onValueChange?: (value: unknown) => void }>, { __onValueChange: onValueChange })
})}
</div>
)
export const DropdownMenuRadioItem = ({
children,
value,
onClick,
__onValueChange,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode, value?: unknown, __onValueChange?: (value: unknown) => void }) => (
<div
role="radio"
onClick={(event) => {
onClick?.(event)
__onValueChange?.(value)
}}
{...props}
>
{children}
</div>
)
export const DropdownMenuRadioItemIndicator = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuCheckboxItem = DropdownMenuItem
export const DropdownMenuCheckboxItemIndicator = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuLabel = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuSeparator = (props: React.HTMLAttributes<HTMLDivElement>) => <div role="separator" {...props} />
export const DropdownMenuSub = ({ children }: { children?: ReactNode }) => <>{children}</>
export const DropdownMenuSubTrigger = DropdownMenuItem
export const DropdownMenuSubContent = ({ children }: { children?: ReactNode }) => <>{children}</>

View File

@ -23,17 +23,25 @@ type PopoverContentProps = React.HTMLAttributes<HTMLDivElement> & {
placement?: string
sideOffset?: number
alignOffset?: number
popupClassName?: string
positionerProps?: React.HTMLAttributes<HTMLDivElement>
popupProps?: React.HTMLAttributes<HTMLDivElement>
}
export const Popover = ({
children,
open = false,
open,
onOpenChange,
}: PopoverProps) => {
const [localOpen, setLocalOpen] = React.useState(false)
const resolvedOpen = open ?? localOpen
const handleOpenChange = React.useCallback((nextOpen: boolean) => {
setLocalOpen(nextOpen)
onOpenChange?.(nextOpen)
}, [onOpenChange])
React.useEffect(() => {
if (!open)
if (!resolvedOpen)
return
const handleMouseDown = (event: MouseEvent) => {
@ -41,12 +49,12 @@ export const Popover = ({
if (target?.closest?.('[data-popover-trigger="true"], [data-popover-content="true"]'))
return
onOpenChange?.(false)
handleOpenChange(false)
}
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape')
onOpenChange?.(false)
handleOpenChange(false)
}
document.addEventListener('mousedown', handleMouseDown)
@ -56,15 +64,15 @@ export const Popover = ({
document.removeEventListener('mousedown', handleMouseDown)
document.removeEventListener('keydown', handleKeyDown)
}
}, [open, onOpenChange])
}, [resolvedOpen, handleOpenChange])
return (
<PopoverContext.Provider value={{
open,
onOpenChange: onOpenChange ?? (() => {}),
open: resolvedOpen,
onOpenChange: handleOpenChange,
}}
>
<div data-testid="popover" data-open={String(open)}>
<div data-testid="popover" data-open={String(resolvedOpen)}>
{children}
</div>
</PopoverContext.Provider>
@ -84,11 +92,12 @@ export const PopoverTrigger = ({
if (React.isValidElement(node)) {
const triggerElement = node as React.ReactElement<Record<string, unknown>>
const childProps = (triggerElement.props ?? {}) as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
const triggerProps = props as React.HTMLAttributes<HTMLElement> & { 'data-testid'?: string }
return React.cloneElement(triggerElement, {
...props,
...childProps,
'data-testid': childProps['data-testid'] ?? 'popover-trigger',
'data-testid': childProps['data-testid'] ?? triggerProps['data-testid'] ?? 'popover-trigger',
'data-popover-trigger': 'true',
'onClick': (event: React.MouseEvent<HTMLElement>) => {
childProps.onClick?.(event)
@ -97,7 +106,7 @@ export const PopoverTrigger = ({
return
onOpenChange(!open)
},
})
}, render ? (children ?? childProps.children) : childProps.children)
}
return (
@ -123,6 +132,7 @@ export const PopoverContent = ({
placement,
sideOffset,
alignOffset,
popupClassName,
positionerProps,
popupProps,
...props
@ -139,7 +149,7 @@ export const PopoverContent = ({
data-placement={placement}
data-side-offset={sideOffset}
data-align-offset={alignOffset}
className={className}
className={className || popupClassName}
{...positionerProps}
{...popupProps}
{...props}

View File

@ -0,0 +1,65 @@
import type { ReactNode } from 'react'
import * as React from 'react'
const SelectContext = React.createContext({
value: undefined as unknown,
onValueChange: (_value: unknown) => {},
})
type SelectProps = {
children?: ReactNode
value?: unknown
onValueChange?: (value: unknown) => void
}
export const Select = ({
children,
value,
onValueChange,
}: SelectProps) => (
<SelectContext.Provider value={{ value, onValueChange: onValueChange ?? (() => {}) }}>
<div data-testid="select-root">{children}</div>
</SelectContext.Provider>
)
export const SelectTrigger = ({
children,
...props
}: React.ButtonHTMLAttributes<HTMLButtonElement> & { children?: ReactNode }) => (
<button type="button" {...props}>
{children}
</button>
)
export const SelectValue = ({ placeholder }: { placeholder?: ReactNode }) => <>{placeholder}</>
export const SelectContent = ({ children }: { children?: ReactNode }) => (
<div data-testid="select-content">{children}</div>
)
export const SelectItem = ({
children,
value,
onClick,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode, value?: unknown }) => {
const select = React.useContext(SelectContext)
return (
<div
role="option"
onClick={(event) => {
onClick?.(event)
select.onValueChange(value)
}}
{...props}
>
{children}
</div>
)
}
export const SelectItemText = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectItemIndicator = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectGroup = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectLabel = ({ children }: { children?: ReactNode }) => <>{children}</>
export const SelectSeparator = (props: React.HTMLAttributes<HTMLDivElement>) => <div role="separator" {...props} />

View File

@ -0,0 +1,95 @@
import type { ReactNode } from 'react'
import * as React from 'react'
const TooltipContext = React.createContext({
open: false,
onOpenChange: (_open: boolean) => {},
})
type TooltipProps = {
children?: ReactNode
open?: boolean
onOpenChange?: (open: boolean) => void
}
export const Tooltip = ({ children, open, onOpenChange }: TooltipProps) => {
const [localOpen, setLocalOpen] = React.useState(false)
const resolvedOpen = open ?? localOpen
const handleOpenChange = React.useCallback((nextOpen: boolean) => {
setLocalOpen(nextOpen)
onOpenChange?.(nextOpen)
}, [onOpenChange])
return (
<TooltipContext.Provider value={{ open: resolvedOpen, onOpenChange: handleOpenChange }}>
{children}
</TooltipContext.Provider>
)
}
export const TooltipTrigger = ({
children,
render,
nativeButton: _nativeButton,
...props
}: React.HTMLAttributes<HTMLElement> & { children?: ReactNode, render?: React.ReactElement, nativeButton?: boolean }) => {
const { open, onOpenChange } = React.useContext(TooltipContext)
const node = render ?? children
if (React.isValidElement(node)) {
const triggerElement = node as React.ReactElement<Record<string, unknown>>
const childProps = (triggerElement.props ?? {}) as React.HTMLAttributes<HTMLElement>
return React.cloneElement(triggerElement, {
...props,
...childProps,
onMouseEnter: (event: React.MouseEvent<HTMLElement>) => {
childProps.onMouseEnter?.(event)
props.onMouseEnter?.(event)
onOpenChange(true)
},
onMouseLeave: (event: React.MouseEvent<HTMLElement>) => {
childProps.onMouseLeave?.(event)
props.onMouseLeave?.(event)
onOpenChange(false)
},
onClick: (event: React.MouseEvent<HTMLElement>) => {
childProps.onClick?.(event)
props.onClick?.(event)
onOpenChange(!open)
},
})
}
return (
<span
{...props}
onMouseEnter={(event) => {
props.onMouseEnter?.(event)
onOpenChange(true)
}}
onMouseLeave={(event) => {
props.onMouseLeave?.(event)
onOpenChange(false)
}}
onClick={(event) => {
props.onClick?.(event)
onOpenChange(!open)
}}
>
{node}
</span>
)
}
export const TooltipContent = ({
children,
...props
}: React.HTMLAttributes<HTMLDivElement> & { children?: ReactNode }) => {
const { open } = React.useContext(TooltipContext)
if (!open)
return null
return <div {...props}>{children}</div>
}
export const TooltipProvider = ({ children }: { children?: ReactNode }) => <>{children}</>

View File

@ -95,37 +95,8 @@ vi.mock('@/app/components/workflow/utils', () => ({
getKeyboardKeyNameBySystem: (key: string) => key,
}))
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const React = await vi.importActual<typeof import('react')>('react')
const OpenContext = React.createContext(false)
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<OpenContext.Provider value={open}>
<div>{children}</div>
</OpenContext.Provider>
),
PortalToFollowElemTrigger: ({
children,
onClick,
}: {
children: React.ReactNode
onClick?: () => void
}) => (
<button type="button" data-testid="portal-trigger" onClick={onClick}>
{children}
</button>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = React.useContext(OpenContext)
return open ? <div>{children}</div> : null
},
}
})
vi.mock('@/app/components/base/tooltip', () => ({
default: ({ children }: { children?: React.ReactNode }) => <>{children}</>,
}))
vi.mock('@langgenius/dify-ui/dropdown-menu', () => import('@/__mocks__/base-ui-dropdown-menu'))
vi.mock('@langgenius/dify-ui/tooltip', () => import('@/__mocks__/base-ui-tooltip'))
vi.mock('@/app/components/app-sidebar/app-info', () => ({
default: ({

View File

@ -122,33 +122,7 @@ vi.mock('@/app/components/app/app-access-control', () => ({
default: () => <div data-testid="app-access-control" />,
}))
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const React = await vi.importActual<typeof import('react')>('react')
const OpenContext = React.createContext(false)
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<OpenContext.Provider value={open}>
<div>{children}</div>
</OpenContext.Provider>
),
PortalToFollowElemTrigger: ({
children,
onClick,
}: {
children: React.ReactNode
onClick?: () => void
}) => (
<div onClick={onClick}>
{children}
</div>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = React.useContext(OpenContext)
return open ? <div>{children}</div> : null
},
}
})
vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover'))
vi.mock('@/app/components/workflow/utils', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',

View File

@ -53,6 +53,16 @@ vi.mock('@/next/navigation', () => ({
}),
}))
vi.mock('@tanstack/react-query', async (importOriginal) => {
const actual = await importOriginal<typeof import('@tanstack/react-query')>()
return {
...actual,
useQuery: () => ({
data: [],
}),
}
})
// Mock headless UI Popover so it renders content without transition
vi.mock('@headlessui/react', async () => {
const actual = await vi.importActual<typeof import('@headlessui/react')>('@headlessui/react')

View File

@ -9,7 +9,7 @@ import type { ReactElement, ReactNode } from 'react'
*/
import type { AppListResponse } from '@/models/app'
import type { App } from '@/types/app'
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { createSystemFeaturesWrapper } from '@/__tests__/utils/mock-system-features'
import List from '@/app/components/apps/list'
@ -92,6 +92,9 @@ vi.mock('@tanstack/react-query', async (importOriginal) => {
const actual = await importOriginal<typeof import('@tanstack/react-query')>()
return {
...actual,
useQuery: () => ({
data: [],
}),
useInfiniteQuery: () => ({
data: { pages: mockPages },
isLoading: mockIsLoading,
@ -360,13 +363,18 @@ describe('App List Browsing Flow', () => {
expect(input).toBeInTheDocument()
})
it('should allow typing in search input', () => {
it('should update search query when typing in search input', async () => {
mockPages = [createPage([createMockApp()])]
renderList()
const { onUrlUpdate } = renderList()
const input = document.querySelector('input')!
const input = screen.getByPlaceholderText('common.operation.search')
fireEvent.change(input, { target: { value: 'test search' } })
expect(input.value).toBe('test search')
await waitFor(() => {
expect(onUrlUpdate).toHaveBeenCalled()
})
const lastCall = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1]![0]
expect(lastCall.searchParams.get('keywords')).toBe('test search')
})
})

View File

@ -121,25 +121,7 @@ vi.mock('../../app-access-control', () => ({
),
}))
vi.mock('@/app/components/base/portal-to-follow-elem', async () => {
const ReactModule = await vi.importActual<typeof import('react')>('react')
const OpenContext = ReactModule.createContext(false)
return {
PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => (
<OpenContext.Provider value={open}>
<div>{children}</div>
</OpenContext.Provider>
),
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => (
<div onClick={onClick}>{children}</div>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => {
const open = ReactModule.useContext(OpenContext)
return open ? <div>{children}</div> : null
},
}
})
vi.mock('@langgenius/dify-ui/popover', () => import('@/__mocks__/base-ui-popover'))
vi.mock('../sections', () => ({
PublisherSummarySection: (props: Record<string, any>) => {

View File

@ -1,8 +1,8 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import SelectVarType from '../select-var-type'
describe('SelectVarType', () => {
it('should open the menu and return the selected variable type', () => {
it('should open the menu and return the selected variable type', async () => {
const onChange = vi.fn()
render(<SelectVarType onChange={onChange} />)
@ -11,6 +11,8 @@ describe('SelectVarType', () => {
fireEvent.click(screen.getByText('appDebug.variableConfig.checkbox'))
expect(onChange).toHaveBeenCalledWith('checkbox')
expect(screen.queryByText('appDebug.variableConfig.checkbox')).not.toBeInTheDocument()
await waitFor(() => {
expect(screen.queryByText('appDebug.variableConfig.checkbox')).not.toBeInTheDocument()
})
})
})

View File

@ -2,13 +2,7 @@
import { fireEvent, render, screen } from '@testing-library/react'
import TypeSelector from '../type-select'
vi.mock('@/app/components/base/portal-to-follow-elem', () => ({
PortalToFollowElem: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => (
<button type="button" onClick={onClick}>{children}</button>
),
PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
}))
vi.mock('@langgenius/dify-ui/select', () => import('@/__mocks__/base-ui-select'))
vi.mock('@/app/components/workflow/nodes/_base/components/input-var-type-icon', () => ({
default: ({ type }: { type: string }) => <span>{type}</span>,

View File

@ -1,16 +1,17 @@
'use client'
import type { FC } from 'react'
import type { InputVarType } from '@/app/components/workflow/types'
import { ChevronDownIcon } from '@heroicons/react/20/solid'
import { cn } from '@langgenius/dify-ui/cn'
import * as React from 'react'
import { useState } from 'react'
import Badge from '@/app/components/base/badge'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
Select,
SelectContent,
SelectItem,
SelectItemIndicator,
SelectItemText,
SelectTrigger,
} from '@langgenius/dify-ui/select'
import * as React from 'react'
import Badge from '@/app/components/base/badge'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
import { inputVarTypeToVarType } from '@/app/components/workflow/nodes/_base/components/variable/utils'
@ -35,21 +36,26 @@ const TypeSelector: FC<Props> = ({
popupInnerClassName,
readonly,
}) => {
const [open, setOpen] = useState(false)
const selectedItem = value ? items.find(item => item.value === value) : undefined
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement="bottom-start"
offset={4}
<Select
value={selectedItem?.value}
readOnly={readonly}
onValueChange={(nextValue) => {
const selected = items.find(item => item.value === nextValue)
if (selected)
onSelect(selected)
}}
>
<PortalToFollowElemTrigger onClick={() => !readonly && setOpen(v => !v)} className="w-full">
<div
className={cn(`group flex h-9 items-center justify-between rounded-lg border-0 bg-components-input-bg-normal px-2 text-sm hover:bg-state-base-hover-alt ${readonly ? 'cursor-not-allowed' : 'cursor-pointer'}`)}
title={selectedItem?.name}
>
<SelectTrigger
className={cn(
'h-9 rounded-lg px-2 text-sm',
readonly ? 'cursor-not-allowed' : 'cursor-pointer',
)}
title={selectedItem?.name}
>
<div className="flex min-w-0 items-center justify-between">
<div className="flex items-center">
<InputVarTypeIcon type={selectedItem?.value as InputVarType} className="size-4 shrink-0 text-text-secondary" />
<span
@ -60,37 +66,35 @@ const TypeSelector: FC<Props> = ({
{selectedItem?.name}
</span>
</div>
<div className="flex items-center space-x-1">
<div className="ml-2 flex shrink-0 items-center space-x-1">
<Badge uppercase={false}>{inputVarTypeToVarType(selectedItem?.value as InputVarType)}</Badge>
<ChevronDownIcon className={cn('h-4 w-4 shrink-0 text-text-quaternary group-hover:text-text-secondary', open && 'text-text-secondary')} />
</div>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className="z-61">
<div
className={cn('w-[432px] rounded-md border-[0.5px] border-components-panel-border bg-components-panel-bg px-1 py-1 text-base shadow-lg focus:outline-hidden sm:text-sm', popupInnerClassName)}
>
{items.map((item: Item) => (
<div
key={item.value}
className="flex h-9 cursor-pointer items-center justify-between rounded-lg px-2 text-text-secondary hover:bg-state-base-hover"
title={item.name}
onClick={() => {
onSelect(item)
setOpen(false)
}}
</SelectTrigger>
<SelectContent
sideOffset={4}
popupClassName={cn('w-[432px] rounded-md px-1 py-1 text-base sm:text-sm', popupInnerClassName)}
listClassName="max-h-80 p-0"
>
{items.map((item: Item) => (
<SelectItem
key={item.value}
value={item.value}
className="h-9 justify-between px-2 text-text-secondary"
title={item.name}
>
<SelectItemText
className="flex items-center space-x-2 px-0"
>
<div className="flex items-center space-x-2">
<InputVarTypeIcon type={item.value} className="size-4 shrink-0 text-text-secondary" />
<span title={item.name}>{item.name}</span>
</div>
<Badge uppercase={false}>{inputVarTypeToVarType(item.value)}</Badge>
</div>
))}
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<InputVarTypeIcon type={item.value} className="size-4 shrink-0 text-text-secondary" />
<span title={item.name}>{item.name}</span>
</SelectItemText>
<Badge uppercase={false}>{inputVarTypeToVarType(item.value)}</Badge>
<SelectItemIndicator />
</SelectItem>
))}
</SelectContent>
</Select>
)
}

View File

@ -1,15 +1,16 @@
'use client'
import type { FC } from 'react'
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from '@langgenius/dify-ui/dropdown-menu'
import * as React from 'react'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import OperationBtn from '@/app/components/app/configuration/base/operation-btn'
import { ApiConnection } from '@/app/components/base/icons/src/vender/solid/development'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import InputVarTypeIcon from '@/app/components/workflow/nodes/_base/components/input-var-type-icon'
import { InputVarType } from '@/app/components/workflow/types'
@ -27,13 +28,14 @@ type ItemProps = {
const SelectItem: FC<ItemProps> = ({ text, type, value, Icon, onClick }) => {
return (
<div
className="flex h-8 cursor-pointer items-center rounded-lg px-3 hover:bg-state-base-hover"
<DropdownMenuItem
closeOnClick
className="h-8 rounded-lg px-3 text-text-primary"
onClick={() => onClick(value)}
>
{Icon ? <Icon className="h-4 w-4 text-text-secondary" /> : <InputVarTypeIcon type={type!} className="h-4 w-4 text-text-secondary" />}
<div className="ml-2 truncate text-xs text-text-primary">{text}</div>
</div>
</DropdownMenuItem>
)
}
@ -41,40 +43,36 @@ const SelectVarType: FC<Props> = ({
onChange,
}) => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
const handleChange = (value: string) => {
onChange(value)
setOpen(false)
}
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement="bottom-end"
offset={{
mainAxis: 8,
crossAxis: -2,
}}
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<DropdownMenu>
<DropdownMenuTrigger
nativeButton={false}
render={<div className="block" />}
>
<OperationBtn type="add" />
</PortalToFollowElemTrigger>
<PortalToFollowElemContent style={{ zIndex: 1000 }}>
<div className="min-w-[192px] rounded-lg border border-components-panel-border bg-components-panel-bg-blur shadow-lg backdrop-blur-xs">
<div className="p-1">
<SelectItem type={InputVarType.textInput} value="string" text={t('variableConfig.string', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.paragraph} value="paragraph" text={t('variableConfig.paragraph', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.select} value="select" text={t('variableConfig.select', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.number} value="number" text={t('variableConfig.number', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.checkbox} value="checkbox" text={t('variableConfig.checkbox', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
<div className="h-px border-t border-components-panel-border"></div>
<div className="p-1">
<SelectItem Icon={ApiConnection} value="api" text={t('variableConfig.apiBasedVar', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
</DropdownMenuTrigger>
<DropdownMenuContent
placement="bottom-end"
sideOffset={8}
alignOffset={-2}
popupClassName="min-w-[192px] rounded-lg border bg-components-panel-bg-blur p-0 backdrop-blur-xs"
>
<div className="p-1">
<SelectItem type={InputVarType.textInput} value="string" text={t('variableConfig.string', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.paragraph} value="paragraph" text={t('variableConfig.paragraph', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.select} value="select" text={t('variableConfig.select', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.number} value="number" text={t('variableConfig.number', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
<SelectItem type={InputVarType.checkbox} value="checkbox" text={t('variableConfig.checkbox', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
<DropdownMenuSeparator className="my-0" />
<div className="p-1">
<SelectItem Icon={ApiConnection} value="api" text={t('variableConfig.apiBasedVar', { ns: 'appDebug' })} onClick={handleChange}></SelectItem>
</div>
</DropdownMenuContent>
</DropdownMenu>
)
}
export default React.memo(SelectVarType)

View File

@ -841,11 +841,11 @@ describe('AssistantTypePicker', () => {
it('should have proper ARIA state for dropdown', async () => {
// Arrange
const user = userEvent.setup()
const { container } = renderComponent()
renderComponent()
// Act - Check initial state
const portalContainer = container.querySelector('[data-state]')
expect(portalContainer)!.toHaveAttribute('data-state', 'closed')
const triggerButton = screen.getByRole('button', { name: /chatAssistant\.name/i })
expect(triggerButton).toHaveAttribute('aria-expanded', 'false')
// Open dropdown
const trigger = screen.getByText(/chatAssistant.name/i)
@ -853,23 +853,22 @@ describe('AssistantTypePicker', () => {
// Assert - State should change to open
await waitFor(() => {
const openPortal = container.querySelector('[data-state="open"]')
expect(openPortal)!.toBeInTheDocument()
expect(triggerButton).toHaveAttribute('aria-expanded', 'true')
})
})
it('should have proper data-state attribute', () => {
// Arrange & Act
const { container } = renderComponent()
renderComponent()
// Assert - Portal should have data-state for accessibility
const portalContainer = container.querySelector('[data-state]')
expect(portalContainer)!.toBeInTheDocument()
expect(portalContainer)!.toHaveAttribute('data-state')
// Assert - Trigger should expose expanded state for accessibility
const triggerButton = screen.getByRole('button', { name: /chatAssistant\.name/i })
expect(triggerButton).toBeInTheDocument()
expect(triggerButton).toHaveAttribute('aria-expanded')
// Should start in closed state
// Should start in closed state
expect(portalContainer)!.toHaveAttribute('data-state', 'closed')
expect(triggerButton).toHaveAttribute('aria-expanded', 'false')
})
it('should maintain accessible structure for screen readers', () => {

Some files were not shown because too many files have changed in this diff Show More