diff --git a/README.md b/README.md index 778028fc76..e6f8d84931 100644 --- a/README.md +++ b/README.md @@ -76,10 +76,11 @@ The easiest way to start the Dify server is through [Docker Compose](docker/dock ```bash cd dify cd docker -cp .env.example .env -docker compose up -d +./dify-compose up -d ``` +On Windows PowerShell, run `.\dify-compose.ps1 up -d` from the `docker` directory. + After running, you can access the Dify dashboard in your browser at [http://localhost/install](http://localhost/install) and start the initialization process. #### Seeking help @@ -137,7 +138,7 @@ Star Dify on GitHub and be instantly notified of new releases. ### Custom configurations -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you need to customize the configuration, add only the values you want to override to `docker/.env`. The default values live in [`docker/.env.default`](docker/.env.default), and the full reference remains in [`docker/.env.example`](docker/.env.example). After making any changes, re-run `./dify-compose up -d` or `.\dify-compose.ps1 up -d` from the `docker` directory. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). ### Metrics Monitoring with Grafana diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py index 0eb43d66f4..e19ccd11e5 100644 --- a/api/extensions/ext_session_factory.py +++ b/api/extensions/ext_session_factory.py @@ -1,7 +1,9 @@ +from flask import Flask + from core.db.session_factory import configure_session_factory from extensions.ext_database import db -def init_app(app): +def init_app(app: Flask): with app.app_context(): configure_session_factory(db.engine) diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 66a25e5daf..b4482674da 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask: @pytest.fixture -def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]: +def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]: """ Request context fixture for containerized Flask application. @@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, @pytest.fixture -def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]: +def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]: """ Test client fixture for containerized Flask application. @@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli @pytest.fixture -def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]: +def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]: """ Database session fixture for containerized testing. diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 18755ef012..bb737754a1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -69,7 +70,7 @@ def _unwrap(func): class TestCompletionEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_completion_create_payload(self): @@ -86,7 +87,7 @@ class TestCompletionEndpoints: ) assert payload.query == "hi" - def test_completion_api_success(self, app, monkeypatch): + def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -116,7 +117,7 @@ class TestCompletionEndpoints: assert resp == {"result": {"text": "ok"}} - def test_completion_api_conversation_not_exists(self, app, monkeypatch): + def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -142,7 +143,7 @@ class TestCompletionEndpoints: with pytest.raises(NotFound): method(app_model=MagicMock(id="app-1")) - def test_completion_api_provider_not_initialized(self, app, monkeypatch): + def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -166,7 +167,7 @@ class TestCompletionEndpoints: with pytest.raises(completion_module.ProviderNotInitializeError): method(app_model=MagicMock(id="app-1")) - def test_completion_api_quota_exceeded(self, app, monkeypatch): + def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -193,10 +194,10 @@ class TestCompletionEndpoints: class TestAppEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppApi() method = _unwrap(api.put) payload = { @@ -234,7 +235,7 @@ class TestAppEndpoints: } ) - def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch): + def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppIconApi() method = _unwrap(api.post) payload = { @@ -266,7 +267,7 @@ class TestAppEndpoints: class TestOpsTraceEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_ops_trace_query_basic(self): @@ -277,7 +278,7 @@ class TestOpsTraceEndpoints: payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"}) assert payload.tracing_config["api_key"] == "k" - def test_trace_app_config_get_empty(self, app, monkeypatch): + def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.get) @@ -292,7 +293,7 @@ class TestOpsTraceEndpoints: assert result == {"has_not_configured": True} - def test_trace_app_config_post_invalid(self, app, monkeypatch): + def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.post) @@ -309,7 +310,7 @@ class TestOpsTraceEndpoints: with pytest.raises(BadRequest): method(app_id="app-1") - def test_trace_app_config_delete_not_found(self, app, monkeypatch): + def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.delete) @@ -326,7 +327,7 @@ class TestOpsTraceEndpoints: class TestSiteEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_site_response_structure(self): @@ -337,7 +338,7 @@ class TestSiteEndpoints: payload = AppSiteUpdatePayload(default_language="en-US") assert payload.default_language == "en-US" - def test_app_site_update_post(self, app, monkeypatch): + def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSite() method = _unwrap(api.post) @@ -375,7 +376,7 @@ class TestSiteEndpoints: assert isinstance(result, dict) assert result["title"] == "My Site" - def test_app_site_access_token_reset(self, app, monkeypatch): + def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) @@ -427,7 +428,7 @@ class TestWorkflowEndpoints: class TestWorkflowAppLogEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_app_log_query(self): @@ -438,7 +439,7 @@ class TestWorkflowAppLogEndpoints: query = WorkflowAppLogQuery(detail="true") assert query.detail is True - def test_workflow_app_log_api_get(self, app, monkeypatch): + def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_app_log_module.WorkflowAppLogApi() method = _unwrap(api.get) @@ -477,14 +478,14 @@ class TestWorkflowAppLogEndpoints: class TestWorkflowDraftVariableEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_variable_creation(self): payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test") assert payload.name == "var1" - def test_workflow_variable_collection_get(self, app, monkeypatch): + def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_draft_variable_module.WorkflowVariableCollectionApi() method = _unwrap(api.get) @@ -529,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints: class TestWorkflowStatisticEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_statistic_time_range(self): @@ -541,7 +542,7 @@ class TestWorkflowStatisticEndpoints: assert query.start is None assert query.end is None - def test_workflow_daily_runs_statistic(self, app, monkeypatch): + def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -567,7 +568,7 @@ class TestWorkflowStatisticEndpoints: assert response.get_json() == {"data": [{"date": "2024-01-01"}]} - def test_workflow_daily_terminals_statistic(self, app, monkeypatch): + def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -598,7 +599,7 @@ class TestWorkflowStatisticEndpoints: class TestWorkflowTriggerEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_webhook_trigger_payload(self): @@ -608,7 +609,7 @@ class TestWorkflowTriggerEndpoints: enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True) assert enable_payload.enable_trigger is True - def test_webhook_trigger_api_get(self, app, monkeypatch): + def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_trigger_module.WebhookTriggerApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index 25d19cf35a..bcb6e41ef7 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from controllers.console.app import app_import as app_import_module from services.app_dsl_service import ImportStatus @@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: class TestAppImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -57,7 +58,7 @@ class TestAppImportApi: assert status == 400 assert response["status"] == ImportStatus.FAILED - def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -75,7 +76,7 @@ class TestAppImportApi: assert status == 202 assert response["status"] == ImportStatus.PENDING - def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -96,7 +97,7 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED - def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -121,7 +122,7 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED - def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -149,10 +150,10 @@ class TestAppImportApi: class TestAppImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportConfirmApi() method = _unwrap(api.post) @@ -172,10 +173,10 @@ class TestAppImportConfirmApi: class TestAppImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportCheckDependenciesApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 320da85b60..1fcce9ca44 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - app, + app: Flask, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False @@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi: mock_revoke, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} @@ -120,7 +121,7 @@ class TestEmailRegisterResetApi: mock_create_account, mock_login, mock_reset_login_rate, - app, + app: Flask, ): mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} mock_create_account.return_value = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index d2703ed5cc..014c1588fe 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} @@ -123,7 +124,7 @@ class TestForgotPasswordResetApi: mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 1eabb45422..01d88d247c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.mark.parametrize( @@ -65,7 +66,7 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -130,7 +131,7 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -394,7 +395,7 @@ class TestOAuthCallback: class TestAccountGeneration: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 50249bcd74..8d6b25b5b3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email.assert_called_once() @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask): """ Test password reset email blocked by IP rate limit. @@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi: mock_reset_rate_limit.assert_called_once_with("user@example.com") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask): """ Test code verification blocked by rate limit. @@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with invalid token. @@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with mismatched email. @@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with incorrect code. @@ -321,7 +322,7 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -375,7 +376,7 @@ class TestForgotPasswordResetApi: mock_revoke_token.assert_called_once_with("valid_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, app): + def test_reset_password_mismatch(self, mock_get_data, app: Flask): """ Test password reset with mismatched passwords. @@ -397,7 +398,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, app): + def test_reset_password_invalid_token(self, mock_get_data, app: Flask): """ Test password reset with invalid token. @@ -418,7 +419,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, app): + def test_reset_password_wrong_phase(self, mock_get_data, app: Flask): """ Test password reset with token not in reset phase. @@ -442,7 +443,7 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask): """ Test password reset for non-existent account. diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index d5ae95dfb7..2752e6b34f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from controllers.console import console_ns @@ -26,7 +27,7 @@ def unwrap(func): class TestPipelineTemplateListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -50,7 +51,7 @@ class TestPipelineTemplateListApi: class TestPipelineTemplateDetailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -115,7 +116,7 @@ class TestPipelineTemplateDetailApi: class TestCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_patch_success(self, app): @@ -193,7 +194,7 @@ class TestCustomizedPipelineTemplateApi: class TestPublishCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_post_success(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index 64e3de2ca3..7624c1150f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden import services @@ -24,13 +25,13 @@ def unwrap(func): class TestCreateRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _valid_payload(self): return {"yaml_content": "name: test"} - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi: assert status == 201 assert response == import_info - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(Forbidden): method(api) - def test_post_dataset_name_duplicate(self, app): + def test_post_dataset_name_duplicate(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi: class TestCreateEmptyRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) @@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi: assert status == 201 assert response == {"id": "ds-1"} - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index cb67892878..f238ca13ee 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( @@ -25,7 +26,7 @@ def unwrap(func): class TestRagPipelineImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _payload(self, mode="create"): @@ -128,7 +129,7 @@ class TestRagPipelineImportApi: class TestRagPipelineImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_confirm_success(self, app): @@ -190,7 +191,7 @@ class TestRagPipelineImportConfirmApi: class TestRagPipelineImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -219,7 +220,7 @@ class TestRagPipelineImportCheckDependenciesApi: class TestRagPipelineExportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_with_include_secret(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index c1f3122c2b..1fdb3057b8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound @@ -45,10 +46,10 @@ def unwrap(func): class TestDraftWorkflowApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_draft_success(self, app): + def test_get_draft_success(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -68,7 +69,7 @@ class TestDraftWorkflowApi: result = method(api, pipeline) assert result == workflow - def test_get_draft_not_exist(self, app): + def test_get_draft_not_exist(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -86,7 +87,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotExist): method(api, pipeline) - def test_sync_hash_not_match(self, app): + def test_sync_hash_not_match(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -111,7 +112,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotSync): method(api, pipeline) - def test_sync_invalid_text_plain(self, app): + def test_sync_invalid_text_plain(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -128,7 +129,7 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 - def test_restore_published_workflow_to_draft_success(self, app): + def test_restore_published_workflow_to_draft_success(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -155,7 +156,7 @@ class TestDraftWorkflowApi: assert result["result"] == "success" assert result["hash"] == "restored-hash" - def test_restore_published_workflow_to_draft_not_found(self, app): + def test_restore_published_workflow_to_draft_not_found(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -179,7 +180,7 @@ class TestDraftWorkflowApi: with pytest.raises(NotFound): method(api, pipeline, "published-workflow") - def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -211,10 +212,10 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_iteration_node_success(self, app): + def test_iteration_node_success(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -240,7 +241,7 @@ class TestDraftRunNodes: result = method(api, pipeline, "node") assert result == {"ok": True} - def test_iteration_node_conversation_not_exists(self, app): + def test_iteration_node_conversation_not_exists(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -262,7 +263,7 @@ class TestDraftRunNodes: with pytest.raises(NotFound): method(api, pipeline, "node") - def test_loop_node_success(self, app): + def test_loop_node_success(self, app: Flask): api = RagPipelineDraftRunLoopNodeApi() method = unwrap(api.post) @@ -290,10 +291,10 @@ class TestDraftRunNodes: class TestPipelineRunApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_draft_run_success(self, app): + def test_draft_run_success(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -325,7 +326,7 @@ class TestPipelineRunApis: ): assert method(api, pipeline) == {"ok": True} - def test_draft_run_rate_limit(self, app): + def test_draft_run_rate_limit(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -356,10 +357,10 @@ class TestPipelineRunApis: class TestDraftNodeRun: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_execution_not_found(self, app): + def test_execution_not_found(self, app: Flask): api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) @@ -387,7 +388,7 @@ class TestDraftNodeRun: class TestPublishedPipelineApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_publish_success(self, app, db_session_with_containers: Session): @@ -436,10 +437,10 @@ class TestPublishedPipelineApis: class TestMiscApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_task_stop(self, app): + def test_task_stop(self, app: Flask): api = RagPipelineTaskStopApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestMiscApis: stop_mock.assert_called_once() assert result["result"] == "success" - def test_transform_forbidden(self, app): + def test_transform_forbidden(self, app: Flask): api = RagPipelineTransformApi() method = unwrap(api.post) @@ -476,7 +477,7 @@ class TestMiscApis: with pytest.raises(Forbidden): method(api, "ds1") - def test_recommended_plugins(self, app): + def test_recommended_plugins(self, app: Flask): api = RagPipelineRecommendedPluginApi() method = unwrap(api.get) @@ -496,10 +497,10 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_published_run_success(self, app): + def test_published_run_success(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi: result = method(api, pipeline) assert result == {"ok": True} - def test_published_run_rate_limit(self, app): + def test_published_run_rate_limit(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi: class TestDefaultBlockConfigApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_block_config_success(self, app): + def test_get_block_config_success(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi: result = method(api, pipeline, "llm") assert result == {"k": "v"} - def test_get_block_config_invalid_json(self, app): + def test_get_block_config_invalid_json(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_published_workflows_success(self, app): + def test_get_published_workflows_success(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi: assert result["items"] == [{"id": "w1"}] assert result["has_more"] is False - def test_get_published_workflows_forbidden(self, app): + def test_get_published_workflows_forbidden(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi: class TestRagPipelineByIdApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -682,7 +683,7 @@ class TestRagPipelineByIdApi: assert result == workflow - def test_patch_no_fields(self, app): + def test_patch_no_fields(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -700,7 +701,7 @@ class TestRagPipelineByIdApi: result, status = method(api, pipeline, "w1") assert status == 400 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -720,7 +721,7 @@ class TestRagPipelineByIdApi: workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) - def test_delete_active_workflow_rejected(self, app): + def test_delete_active_workflow_rejected(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -733,10 +734,10 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_last_run_success(self, app): + def test_last_run_success(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi: result = method(api, pipeline, "node1") assert result == node_exec - def test_last_run_not_found(self, app): + def test_last_run_not_found(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_set_datasource_variables_success(self, app): + def test_set_datasource_variables_success(self, app: Flask): api = RagPipelineDatasourceVariableApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 1c4c6a899f..50ad92afa1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console.datasets import data_source @@ -51,7 +52,7 @@ def mock_engine(): class TestDataSourceApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): @@ -188,7 +189,7 @@ class TestDataSourceApi: class TestDataSourceNotionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_credential_not_found(self, app, patch_tenant): @@ -323,7 +324,7 @@ class TestDataSourceNotionListApi: class TestDataSourceNotionApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_preview_success(self, app, patch_tenant): @@ -381,7 +382,7 @@ class TestDataSourceNotionApi: class TestDataSourceNotionDatasetSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): @@ -424,7 +425,7 @@ class TestDataSourceNotionDatasetSyncApi: class TestDataSourceNotionDocumentSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 83492048ef..0b53ca5585 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.conversation as conversation_module @@ -53,7 +54,7 @@ def user(): class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, chat_app, user): @@ -108,7 +109,7 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_delete_success(self, app, chat_app, user): @@ -156,7 +157,7 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_rename_success(self, app, chat_app, user): @@ -197,7 +198,7 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_pin_success(self, app, chat_app, user): @@ -219,7 +220,7 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_unpin_success(self, app, chat_app, user): diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index f2e7104b18..d944613886 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -6,6 +6,7 @@ import json from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -60,7 +61,7 @@ def _mock_user_tenant(): @pytest.fixture -def client(flask_app_with_containers): +def client(flask_app_with_containers: Flask): return flask_app_with_containers.test_client() @@ -147,10 +148,10 @@ class TestUtils: class TestToolProviderListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ToolProviderListApi() method = unwrap(api.get) @@ -170,10 +171,10 @@ class TestToolProviderListApi: class TestBuiltinProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolBuiltinProviderListToolsApi() method = unwrap(api.get) @@ -190,7 +191,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == [{"a": 1}] - def test_info(self, app): + def test_info(self, app: Flask): api = ToolBuiltinProviderInfoApi() method = unwrap(api.get) @@ -207,7 +208,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"x": 1} - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolBuiltinProviderDeleteApi() method = unwrap(api.post) @@ -224,7 +225,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["result"] == "success" - def test_add_invalid_type(self, app): + def test_add_invalid_type(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -238,7 +239,7 @@ class TestBuiltinProviderApis: with pytest.raises(ValueError): method(api, "provider") - def test_add_success(self, app): + def test_add_success(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -257,7 +258,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["id"] == 1 - def test_update(self, app): + def test_update(self, app: Flask): api = ToolBuiltinProviderUpdateApi() method = unwrap(api.post) @@ -276,7 +277,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credentials(self, app): + def test_get_credentials(self, app: Flask): api = ToolBuiltinProviderGetCredentialsApi() method = unwrap(api.get) @@ -293,7 +294,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"k": "v"} - def test_icon(self, app): + def test_icon(self, app: Flask): api = ToolBuiltinProviderIconApi() method = unwrap(api.get) @@ -307,7 +308,7 @@ class TestBuiltinProviderApis: response = method(api, "provider") assert response.mimetype == "image/png" - def test_credentials_schema(self, app): + def test_credentials_schema(self, app: Flask): api = ToolBuiltinProviderCredentialsSchemaApi() method = unwrap(api.get) @@ -324,7 +325,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider", "oauth2") == {"schema": {}} - def test_set_default_credential(self, app): + def test_set_default_credential(self, app: Flask): api = ToolBuiltinProviderSetDefaultApi() method = unwrap(api.post) @@ -341,7 +342,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credential_info(self, app): + def test_get_credential_info(self, app: Flask): api = ToolBuiltinProviderGetCredentialInfoApi() method = unwrap(api.get) @@ -358,7 +359,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"info": "x"} - def test_get_oauth_client_schema(self, app): + def test_get_oauth_client_schema(self, app: Flask): api = ToolBuiltinProviderGetOauthClientSchemaApi() method = unwrap(api.get) @@ -378,10 +379,10 @@ class TestBuiltinProviderApis: class TestApiProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_add(self, app): + def test_add(self, app: Flask): api = ToolApiProviderAddApi() method = unwrap(api.post) @@ -406,7 +407,7 @@ class TestApiProviderApis: ): assert method(api)["id"] == 1 - def test_remote_schema(self, app): + def test_remote_schema(self, app: Flask): api = ToolApiProviderGetRemoteSchemaApi() method = unwrap(api.get) @@ -423,7 +424,7 @@ class TestApiProviderApis: ): assert method(api)["schema"] == "x" - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolApiProviderListToolsApi() method = unwrap(api.get) @@ -440,7 +441,7 @@ class TestApiProviderApis: ): assert method(api) == [{"tool": 1}] - def test_update(self, app): + def test_update(self, app: Flask): api = ToolApiProviderUpdateApi() method = unwrap(api.post) @@ -468,7 +469,7 @@ class TestApiProviderApis: ): assert method(api)["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolApiProviderDeleteApi() method = unwrap(api.post) @@ -485,7 +486,7 @@ class TestApiProviderApis: ): assert method(api)["result"] == "success" - def test_get(self, app): + def test_get(self, app: Flask): api = ToolApiProviderGetApi() method = unwrap(api.get) @@ -505,10 +506,10 @@ class TestApiProviderApis: class TestWorkflowApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create(self, app): + def test_create(self, app: Flask): api = ToolWorkflowProviderCreateApi() method = unwrap(api.post) @@ -534,7 +535,7 @@ class TestWorkflowApis: ): assert method(api)["id"] == 1 - def test_update_invalid(self, app): + def test_update_invalid(self, app: Flask): api = ToolWorkflowProviderUpdateApi() method = unwrap(api.post) @@ -560,7 +561,7 @@ class TestWorkflowApis: result = method(api) assert result["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolWorkflowProviderDeleteApi() method = unwrap(api.post) @@ -577,7 +578,7 @@ class TestWorkflowApis: ): assert method(api)["ok"] - def test_get_error(self, app): + def test_get_error(self, app: Flask): api = ToolWorkflowProviderGetApi() method = unwrap(api.get) @@ -594,10 +595,10 @@ class TestWorkflowApis: class TestLists: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_builtin_list(self, app): + def test_builtin_list(self, app: Flask): api = ToolBuiltinListApi() method = unwrap(api.get) @@ -617,7 +618,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_api_list(self, app): + def test_api_list(self, app: Flask): api = ToolApiListApi() method = unwrap(api.get) @@ -637,7 +638,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_workflow_list(self, app): + def test_workflow_list(self, app: Flask): api = ToolWorkflowListApi() method = unwrap(api.get) @@ -660,10 +661,10 @@ class TestLists: class TestLabels: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_labels(self, app): + def test_labels(self, app: Flask): api = ToolLabelsApi() method = unwrap(api.get) @@ -679,10 +680,10 @@ class TestLabels: class TestOAuth: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_no_client(self, app): + def test_oauth_no_client(self, app: Flask): api = ToolPluginOAuthApi() method = unwrap(api.get) @@ -700,7 +701,7 @@ class TestOAuth: with pytest.raises(Forbidden): method(api, "provider") - def test_oauth_callback_no_cookie(self, app): + def test_oauth_callback_no_cookie(self, app: Flask): api = ToolOAuthCallback() method = unwrap(api.get) @@ -711,10 +712,10 @@ class TestOAuth: class TestOAuthCustomClient: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_save_custom_client(self, app): + def test_save_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.post) @@ -731,7 +732,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider")["ok"] - def test_get_custom_client(self, app): + def test_get_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.get) @@ -748,7 +749,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider") == {"client_id": "x"} - def test_delete_custom_client(self, app): + def test_delete_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.delete) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index ca8195af53..6efdaf2943 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden from controllers.console.workspace.trigger_providers import ( @@ -45,7 +46,7 @@ def mock_user(): class TestTriggerProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_icon_success(self, app): @@ -93,7 +94,7 @@ class TestTriggerProviderApis: class TestTriggerSubscriptionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_list_success(self, app): @@ -128,7 +129,7 @@ class TestTriggerSubscriptionListApi: class TestTriggerSubscriptionBuilderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_create_builder(self, app): @@ -236,7 +237,7 @@ class TestTriggerSubscriptionBuilderApis: class TestTriggerSubscriptionCrud: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_update_rename_only(self, app): @@ -342,7 +343,7 @@ class TestTriggerSubscriptionCrud: class TestTriggerOAuthApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_oauth_authorize_success(self, app): @@ -480,7 +481,7 @@ class TestTriggerOAuthApis: class TestTriggerOAuthClientManageApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_client(self, app): @@ -556,7 +557,7 @@ class TestTriggerOAuthClientManageApi: class TestTriggerSubscriptionVerifyApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_verify_success(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 437b199ec2..5791d2f6e2 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -18,6 +18,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound @@ -246,7 +247,7 @@ def _unwrap(method): @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): # Uses the full containerised app so that Flask config, extensions, and # blueprint registrations match production. Most tests mock the service # layer to isolate controller logic; a few (e.g. test_list_tags_from_db) diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py index e1e6741014..c34da27ebe 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.web.conversation import ( @@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace: class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context("/conversations"): with pytest.raises(NotChatAppError): ConversationListApi().get(_completion_app(), _end_user()) @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id") - def test_happy_path(self, mock_paginate: MagicMock, app) -> None: + def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None: conv_id = str(uuid4()) conv = SimpleNamespace( id=conv_id, @@ -65,16 +66,16 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}"): with pytest.raises(NotChatAppError): ConversationApi().delete(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.ConversationService.delete") - def test_delete_success(self, mock_delete: MagicMock, app) -> None: + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}"): result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) @@ -83,7 +84,7 @@ class TestConversationApi: assert result["result"] == "success" @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) - def test_delete_not_found(self, mock_delete: MagicMock, app) -> None: + def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}"): with pytest.raises(NotFound, match="Conversation Not Exists"): @@ -92,17 +93,17 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}): with pytest.raises(NotChatAppError): ConversationRenameApi().post(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.ConversationService.rename") @patch("controllers.web.conversation.web_ns") - def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None: + def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: c_id = uuid4() mock_ns.payload = {"name": "New Name", "auto_generate": False} conv = SimpleNamespace( @@ -126,7 +127,7 @@ class TestConversationRenameApi: side_effect=ConversationNotExistsError(), ) @patch("controllers.web.conversation.web_ns") - def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None: + def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: c_id = uuid4() mock_ns.payload = {"name": "X", "auto_generate": False} @@ -137,16 +138,16 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"): with pytest.raises(NotChatAppError): ConversationPinApi().patch(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.WebConversationService.pin") - def test_pin_success(self, mock_pin: MagicMock, app) -> None: + def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id) @@ -154,7 +155,7 @@ class TestConversationPinApi: assert result["result"] == "success" @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError()) - def test_pin_not_found(self, mock_pin: MagicMock, app) -> None: + def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): with pytest.raises(NotFound): @@ -163,16 +164,16 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"): with pytest.raises(NotChatAppError): ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.WebConversationService.unpin") - def test_unpin_success(self, mock_unpin: MagicMock, app) -> None: + def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"): result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id) diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 635cfee2da..2c6a990240 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.web.forgot_password import ( ForgotPasswordCheckApi, @@ -29,7 +30,7 @@ def _patch_wraps(): class TestForgotPasswordSendEmailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") @@ -42,7 +43,7 @@ class TestForgotPasswordSendEmailApi: mock_rate_limit, mock_get_account, mock_send_mail, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -64,7 +65,7 @@ class TestForgotPasswordSendEmailApi: class TestForgotPasswordCheckApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") @@ -81,7 +82,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} @@ -117,7 +118,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"} @@ -142,7 +143,7 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @@ -157,7 +158,7 @@ class TestForgotPasswordResetApi: mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() @@ -194,7 +195,7 @@ class TestForgotPasswordResetApi: mock_db, mock_token_bytes, mock_hash_password, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py index 19833cc772..de9e691434 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound, Unauthorized @@ -182,7 +183,7 @@ class TestValidateUserAccessibility: class TestDecodeJwtToken: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True): diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index c342e8994b..bd13527e14 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -85,7 +85,7 @@ class TestPauseStatePersistenceLayerTestContainers: return WorkflowRunService(engine) @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): + def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service): """Set up test data for each test method using TestContainers.""" # Create test tenant and account from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus @@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers: generate_entity=entity, ) - def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers): + def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session): """Test complete pause flow: event -> state serialization -> database save -> storage save.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert isinstance(persisted_entity, WorkflowAppGenerateEntity) assert persisted_entity.workflow_execution_id == self.test_workflow_run_id - def test_state_persistence_and_retrieval(self, db_session_with_containers): + def test_state_persistence_and_retrieval(self, db_session_with_containers: Session): """Test that pause state can be persisted and retrieved correctly.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert retrieved_state["node_run_steps"] == 10 assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id - def test_database_transaction_handling(self, db_session_with_containers): + def test_database_transaction_handling(self, db_session_with_containers: Session): """Test that database transactions are handled correctly.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert pause_model.resumed_at is None assert pause_model.state_object_key != "" - def test_file_storage_integration(self, db_session_with_containers): + def test_file_storage_integration(self, db_session_with_containers: Session): """Test integration with file storage system.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps() assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id - def test_workflow_with_different_creators(self, db_session_with_containers): + def test_workflow_with_different_creators(self, db_session_with_containers: Session): """Test pause state with workflows created by different users.""" # Arrange - Create workflow with different creator different_user_id = str(uuid.uuid4()) @@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers: resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id - def test_layer_ignores_non_pause_events(self, db_session_with_containers): + def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session): """Test that layer ignores non-pause events.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -562,7 +562,7 @@ class TestPauseStatePersistenceLayerTestContainers: ).all() assert len(pause_states) == 0 - def test_layer_requires_initialization(self, db_session_with_containers): + def test_layer_requires_initialization(self, db_session_with_containers: Session): """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index a60159c66a..54ee133bfe 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -15,6 +15,7 @@ from uuid import uuid4 import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client @@ -40,7 +41,7 @@ class TestTenantIsolatedTaskQueueIntegration: return Faker() @pytest.fixture - def test_tenant_and_account(self, db_session_with_containers, fake): + def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker): """Create test tenant and account for testing.""" # Create account account = Account( @@ -94,7 +95,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" assert queue._task_key == f"tenant_test-key_task:{tenant.id}" - def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake): + def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker): """Test that different tenants have isolated queues.""" tenant1, _ = test_tenant_and_account @@ -176,7 +177,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert len(remaining_tasks) == 2 assert remaining_tasks == ["task4", "task5"] - def test_push_and_pull_complex_objects(self, test_queue, fake): + def test_push_and_pull_complex_objects(self, test_queue, fake: Faker): """Test pushing and pulling complex object tasks.""" # Create complex task objects as dictionaries (not dataclass instances) tasks = [ @@ -218,7 +219,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert pulled_task["data"] == original_task["data"] assert pulled_task["metadata"] == original_task["metadata"] - def test_mixed_task_types(self, test_queue, fake): + def test_mixed_task_types(self, test_queue, fake: Faker): """Test pushing and pulling mixed string and object tasks.""" string_task = "simple_string_task" object_task = { @@ -267,7 +268,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Verify task key has expired assert test_queue.get_task_key() is None - def test_large_task_batch(self, test_queue, fake): + def test_large_task_batch(self, test_queue, fake: Faker): """Test handling large batches of tasks.""" # Create large batch of tasks large_batch = [] @@ -292,7 +293,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert isinstance(task, dict) assert task["index"] == i # FIFO order - def test_queue_operations_isolation(self, test_tenant_and_account, fake): + def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker): """Test concurrent operations on different queues.""" tenant, _ = test_tenant_and_account @@ -312,7 +313,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert tasks2 == ["task1_queue2", "task2_queue2"] assert tasks1 != tasks2 - def test_task_wrapper_serialization_roundtrip(self, test_queue, fake): + def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker): """Test TaskWrapper serialization and deserialization roundtrip.""" # Create complex nested data complex_data = { @@ -346,7 +347,7 @@ class TestTenantIsolatedTaskQueueIntegration: task = test_queue.pull_tasks(1) assert task[0] == invalid_json_task - def test_real_world_batch_processing_scenario(self, test_queue, fake): + def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker): """Test realistic batch processing scenario.""" # Simulate batch processing tasks batch_tasks = [] @@ -403,7 +404,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return Faker() @pytest.fixture - def test_tenant_and_account(self, db_session_with_containers, fake): + def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker): """Create test tenant and account for testing.""" # Create account account = Account( @@ -435,7 +436,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return tenant, account - def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake): + def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker): """ Test compatibility with legacy queues containing only string data. @@ -465,7 +466,7 @@ class TestTenantIsolatedTaskQueueCompatibility: expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] assert pulled_tasks == expected_order - def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake): + def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker): """ Test complete migration scenario from legacy to new system. @@ -546,7 +547,7 @@ class TestTenantIsolatedTaskQueueCompatibility: assert task["tenant_id"] == tenant.id assert task["processing_type"] == "new_system" - def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake): + def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker): """ Test error recovery when legacy queue contains malformed data. diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 00d7496a40..9da6b04a2c 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw class TestGetAvailableDatasetsIntegration: def test_returns_datasets_with_available_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].name == dataset.name def test_filters_out_datasets_with_only_archived_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_filters_out_datasets_with_only_disabled_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_filters_out_datasets_with_non_completed_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_includes_external_datasets_without_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that external datasets are returned even with no available documents. @@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].id == dataset.id assert result[0].provider == "external" - def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): # Arrange fake = Faker() @@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].tenant_id == tenant1.id def test_returns_empty_list_when_no_datasets_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration: # Assert assert result == [] - def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies): + def test_returns_only_requested_dataset_ids( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): # Arrange fake = Faker() @@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration: class TestKnowledgeRetrievalIntegration: def test_knowledge_retrieval_with_available_datasets( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration: assert isinstance(result, list) def test_knowledge_retrieval_no_available_datasets( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration: assert result == [] def test_knowledge_retrieval_rate_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index 177fb95ff3..e71079829f 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_service import ApiKeyAuthService @@ -31,7 +32,7 @@ class TestApiKeyAuthService: def mock_args(self, category, provider, mock_credentials) -> dict: return {"category": category, "provider": provider, "credentials": mock_credentials} - def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False): binding = DataSourceApiKeyAuthBinding( tenant_id=tenant_id, category=category, @@ -44,7 +45,7 @@ class TestApiKeyAuthService: return binding def test_get_provider_auth_list_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() @@ -56,14 +57,16 @@ class TestApiKeyAuthService: assert len(tenant_results) == 1 assert tenant_results[0].provider == provider - def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_get_provider_auth_list_empty( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): result = ApiKeyAuthService.get_provider_auth_list(tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] def test_get_provider_auth_list_filters_disabled( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True @@ -78,7 +81,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_success( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -97,7 +106,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_validation_failed( - self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = False @@ -112,7 +121,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_encrypts_api_key( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -128,7 +143,13 @@ class TestApiKeyAuthService: mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) def test_get_auth_credentials_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + self, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + category, + provider, + mock_credentials, ): self._create_binding( db_session_with_containers, @@ -144,14 +165,14 @@ class TestApiKeyAuthService: assert result == mock_credentials def test_get_auth_credentials_not_found( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) assert result is None def test_get_auth_credentials_json_parsing( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} self._create_binding( @@ -169,7 +190,7 @@ class TestApiKeyAuthService: assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" def test_delete_provider_auth_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): binding = self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider @@ -183,7 +204,9 @@ class TestApiKeyAuthService: remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() assert remaining is None - def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_delete_provider_auth_not_found( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): # Should not raise when binding not found ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index f48c6da690..e78fa27976 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -10,6 +10,7 @@ from uuid import uuid4 import httpx import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory @@ -114,7 +115,7 @@ class TestAuthIntegration: assert result2[0].tenant_id == tenant_id_2 def test_cross_tenant_access_prevention( - self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category ): result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index 42d587b7f7..327f14ddfe 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType @@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument: "user_id": user_id, } - def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_waiting_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in waiting state. @@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument: mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True") def test_pause_document_indexing_state_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful pause of document in indexing state. @@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument: assert document.is_paused is True assert document.paused_by == mock_document_service_dependencies["user_id"] - def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_parsing_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in parsing state. @@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is True - def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_completed_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause completed document. @@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is False - def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_error_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause document in error state. @@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument: "recover_task": mock_task, } - def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_paused_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful recovery of paused document. @@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument: document.dataset_id, document.id ) - def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_not_paused_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to recover non-paused document. @@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument: "user_id": user_id, } - def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_single_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of single document. @@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument: dataset.id, [document.id], mock_document_service_dependencies["user_id"] ) - def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_multiple_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of multiple documents. @@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument: ) def test_retry_document_concurrent_retry_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is already being retried. @@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument: assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when current_user is missing. @@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: } def test_batch_update_document_status_enable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch enabling of documents. @@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: assert mock_document_service_dependencies["add_task"].delay.call_count == 2 def test_batch_update_document_status_disable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch disabling of documents. @@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_archive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch archiving of documents. @@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_unarchive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch unarchiving of documents. @@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_empty_list( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test handling of empty document list. @@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_not_called() def test_batch_update_document_status_document_indexing_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is being indexed. @@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument: "current_user": mock_current_user, } - def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies): """ Test successful document renaming. @@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument: assert result == document assert document.name == new_name - def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_built_in_fields( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with built-in fields enabled. @@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument: assert document.doc_metadata["document_name"] == new_name assert document.doc_metadata["existing_key"] == "existing_value" - def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_upload_file( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with associated upload file. @@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument: assert upload_file.name == new_name def test_rename_document_dataset_not_found_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when dataset is not found. @@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Dataset not found"): DocumentService.rename_document(dataset_id, document_id, new_name) - def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_not_found_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when document is not found. @@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Document not found"): DocumentService.rename_document(dataset.id, document_id, new_name) - def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_permission_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when user lacks permission. diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py index 4e8255d8ed..e73c2afe7f 100644 --- a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from redis import RedisError +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from models.account import TenantAccountJoin @@ -122,7 +123,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_multiple_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -144,7 +145,7 @@ class TestSyncAccountDeletion: assert queued_workspace_ids == set(tenant_ids) def test_sync_account_deletion_no_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: mock_config.ENTERPRISE_ENABLED = True @@ -155,7 +156,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_partial_failure( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -180,7 +181,7 @@ class TestSyncAccountDeletion: assert mock_queue_task.call_count == 3 def test_sync_account_deletion_all_failures( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_id = str(uuid4()) diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py index 2b842629a7..724dd19f92 100644 --- a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -3,6 +3,8 @@ from __future__ import annotations from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from models.model import App, RecommendedApp, Site from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_type import RecommendAppType @@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval: class TestFetchRecommendedAppsFromDb: - def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb: assert "assistant" in result["categories"] assert "writing" in result["categories"] - def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + def test_falls_back_to_default_language_when_empty( + self, flask_app_with_containers, db_session_with_containers: Session + ): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id in app_ids - def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_site(db_session_with_containers, app_id=app1.id) @@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id not in app_ids - def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb: class TestFetchRecommendedAppDetailFromDb: - def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session): result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) assert result is None - def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb: assert result is None @patch("services.recommend_app.database.database_retrieval.AppDslService") - def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 3ec265d009..f78037e503 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -2,6 +2,7 @@ import copy import pytest from faker import Faker +from sqlalchemy.orm import Session from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService: # for consistency with other test files return {} - def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_baichuan_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for Baichuan model. @@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_common_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for common models. @@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_prompt_case_insensitive_baichuan_detection( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan model detection is case insensitive. @@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text def test_get_common_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for chat app with completion mode. @@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_chat_app_chat_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation for chat app with chat mode. @@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with completion mode. @@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with chat mode. @@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService: assert CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation without context. @@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_common_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported app mode. @@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_common_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported model mode. @@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService: # Assert: Verify empty dict is returned assert result == {} - def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_completion_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test completion prompt generation with context. @@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService: assert result_text == CONTEXT + original_text def test_get_completion_prompt_without_context( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test completion prompt generation without context. @@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService: assert result_text == original_text assert CONTEXT not in result_text - def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation with context. @@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService: assert original_text in result_text assert result_text == CONTEXT + original_text - def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_without_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation without context. @@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService: assert CONTEXT not in result_text def test_get_baichuan_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with completion mode. @@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_chat_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with chat mode. @@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with completion mode. @@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with chat mode. @@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_baichuan_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test Baichuan prompt generation without context. @@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported app mode. @@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_baichuan_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported model mode. @@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_prompt_all_app_modes_common_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with common model. @@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService: assert result != {} def test_get_prompt_all_app_modes_baichuan_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with Baichuan model. @@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService: assert result is not None assert result != {} - def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test prompt generation with edge cases. @@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService: # Should either return a valid result or empty dict, but not crash assert result is not None - def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that original templates are not modified. @@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService: assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_baichuan_template_immutability( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that original Baichuan templates are not modified. @@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService: assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies): + def test_context_integration_consistency( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test consistency of context integration across different scenarios. @@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService: assert prompt_text.startswith(CONTEXT) def test_baichuan_context_integration_consistency( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test consistency of Baichuan context integration across different scenarios. diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 1835650c42..6b844615b5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -10,6 +10,8 @@ from uuid import uuid4 import pytest import yaml from faker import Faker +from flask import Flask +from sqlalchemy.orm import Session from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -88,7 +90,7 @@ class TestAppDslService: """Integration tests for AppDslService using testcontainers.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -129,7 +131,7 @@ class TestAppDslService: "enterprise_service": mock_enterprise_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): fake = Faker() with patch("services.account_service.FeatureService") as mock_account_feature_service: mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -206,7 +208,7 @@ class TestAppDslService: # ── Import: Validation ──────────────────────────────────────────── - def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers): + def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Invalid import_mode"): service.import_app( @@ -215,7 +217,7 @@ class TestAppDslService: yaml_content="version: '0.1.0'", ) - def test_import_app_missing_yaml_content(self, db_session_with_containers): + def test_import_app_missing_yaml_content(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -225,7 +227,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "yaml_content is required" in result.error - def test_import_app_missing_yaml_url(self, db_session_with_containers): + def test_import_app_missing_yaml_url(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -235,7 +237,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "yaml_url is required" in result.error - def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers): + def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -245,7 +247,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "content must be a mapping" in result.error - def test_import_app_version_not_str_returns_failed(self, db_session_with_containers): + def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) result = service.import_app( @@ -256,7 +258,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Invalid version type" in result.error - def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers): + def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -266,7 +268,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Missing app data" in result.error - def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): def bad_safe_load(_content: str): raise yaml.YAMLError("bad") @@ -281,7 +283,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert result.error.startswith("Invalid YAML format:") - def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): monkeypatch.setattr( AppDslService, "_create_or_update_app", @@ -299,7 +301,7 @@ class TestAppDslService: # ── Import: YAML URL ────────────────────────────────────────────── - def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): monkeypatch.setattr( app_dsl_service.ssrf_proxy, "get", @@ -315,7 +317,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Error fetching YAML from URL: boom" in result.error - def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch): response = MagicMock() response.content = b"" response.raise_for_status.return_value = None @@ -330,7 +332,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Empty content" in result.error - def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch): response = MagicMock() response.content = b"x" * (DSL_MAX_SIZE + 1) response.raise_for_status.return_value = None @@ -345,7 +347,9 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "File size exceeds" in result.error - def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_user_attachments_keeps_original_url( + self, db_session_with_containers: Session, monkeypatch + ): yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml" yaml_bytes = _pending_yaml_content() @@ -371,7 +375,7 @@ class TestAppDslService: assert result.imported_dsl_version == "99.0.0" assert requested_urls == [yaml_url] - def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch): yaml_url = "https://github.com/acme/repo/blob/main/app.yml" raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml" yaml_bytes = _pending_yaml_content() @@ -400,7 +404,7 @@ class TestAppDslService: # ── Import: App ID checks ──────────────────────────────────────── - def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers): + def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -412,7 +416,7 @@ class TestAppDslService: assert result.error == "App not found" def test_import_app_overwrite_only_allows_workflow_and_advanced_chat( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) assert app.mode == "chat" @@ -429,7 +433,7 @@ class TestAppDslService: # ── Import: Flow ────────────────────────────────────────────────── - def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers): + def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -449,7 +453,7 @@ class TestAppDslService: assert stored is not None def test_import_app_completed_uses_declared_dependencies( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) @@ -483,7 +487,7 @@ class TestAppDslService: @pytest.mark.parametrize("has_workflow", [True, False]) def test_import_app_legacy_versions_extract_dependencies( - self, db_session_with_containers, monkeypatch, has_workflow: bool + self, db_session_with_containers: Session, monkeypatch, has_workflow: bool ): monkeypatch.setattr( AppDslService, @@ -540,13 +544,13 @@ class TestAppDslService: # ── Confirm Import ──────────────────────────────────────────────── - def test_confirm_import_expired_returns_failed(self, db_session_with_containers): + def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.confirm_import(import_id=str(uuid4()), account=_account_mock()) assert result.status == ImportStatus.FAILED assert "expired" in result.error - def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch): + def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" @@ -579,7 +583,7 @@ class TestAppDslService: assert result.app_id == created_app.id assert redis_client.get(redis_key) is None - def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers): + def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123") @@ -589,7 +593,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "validation error" in result.error - def test_confirm_import_exception_returns_failed(self, db_session_with_containers): + def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json") @@ -600,13 +604,13 @@ class TestAppDslService: # ── Check Dependencies ──────────────────────────────────────────── - def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers): + def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) app_model = _app_stub() result = service.check_dependencies(app_model=app_model) assert result.leaked_dependencies == [] - def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch): + def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch): app_id = str(uuid4()) pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id) redis_client.setex( @@ -634,7 +638,9 @@ class TestAppDslService: result = service.check_dependencies(app_model=_app_stub(id=app_id)) assert len(result.leaked_dependencies) == 1 - def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_dependencies_with_real_app( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' @@ -650,12 +656,12 @@ class TestAppDslService: # ── Create/Update App ───────────────────────────────────────────── - def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers): + def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="loss app mode"): service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) - def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch): + def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch): fixed_now = object() monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) @@ -707,7 +713,7 @@ class TestAppDslService: assert app.icon_background == "#222222" assert app.updated_at is fixed_now - def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers): + def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session): account = _account_mock() account.current_tenant_id = None service = AppDslService(db_session_with_containers) @@ -719,7 +725,7 @@ class TestAppDslService: ) def test_create_or_update_app_creates_workflow_app_and_saves_dependencies( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) @@ -755,7 +761,7 @@ class TestAppDslService: stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}") assert stored is not None - def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers): + def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing workflow data"): service._create_or_update_app( @@ -764,7 +770,7 @@ class TestAppDslService: account=_account_mock(), ) - def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers): + def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing model_config"): service._create_or_update_app( @@ -774,7 +780,7 @@ class TestAppDslService: ) def test_create_or_update_app_chat_creates_model_config_and_sends_event( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.app_model_config_id = None @@ -793,7 +799,7 @@ class TestAppDslService: db_session_with_containers.expire_all() assert app.app_model_config_id is not None - def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers): + def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Invalid app mode"): service._create_or_update_app( @@ -870,7 +876,7 @@ class TestAppDslService: assert data["app"]["icon_type"] == "image" assert data["app"]["icon_background"] == "#FFEAD5" - def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) model_config = AppModelConfig( @@ -908,7 +914,9 @@ class TestAppDslService: assert "model_config" in exported_data assert "dependencies" in exported_data - def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_workflow_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" db_session_with_containers.commit() @@ -941,7 +949,9 @@ class TestAppDslService: assert "workflow" in exported_data assert "dependencies" in exported_data - def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_with_workflow_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" db_session_with_containers.commit() @@ -981,7 +991,7 @@ class TestAppDslService: assert "workflow" in exported_data def test_export_dsl_with_invalid_workflow_id_raises_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py index 768a8baee2..d0c07f0de8 100644 --- a/api/tests/test_containers_integration_tests/services/test_attachment_service.py +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -7,7 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound import services.attachment_service as attachment_service_module @@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService class TestAttachmentService: - def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile: upload_file = UploadFile( tenant_id=tenant_id or str(uuid4()), storage_type=StorageType.OPENDAL, @@ -60,7 +60,7 @@ class TestAttachmentService: with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): AttachmentService(session_factory=invalid_session_factory) - def test_should_return_base64_when_file_exists(self, db_session_with_containers): + def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session): upload_file = self._create_upload_file(db_session_with_containers) service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) @@ -70,7 +70,7 @@ class TestAttachmentService: assert result == base64.b64encode(b"binary-content").decode() mock_load.assert_called_once_with(upload_file.key) - def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session): service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) with patch.object(attachment_service_module.storage, "load_once") as mock_load: diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 8092c7ad75..4893126d7f 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from extensions.ext_redis import redis_client @@ -24,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache: """ @pytest.fixture(autouse=True) - def setup_redis_cleanup(self, flask_app_with_containers): + def setup_redis_cleanup(self, flask_app_with_containers: Flask): """Clean up Redis cache before and after each test.""" with flask_app_with_containers.app_context(): # Clean up before test @@ -56,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache: return value return None - def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -87,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was not called mock_get_plan_bulk.assert_not_called() - def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are not in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -127,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1 > 0 assert ttl_1 <= 600 # Should be <= 600 seconds - def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when some tenants are in cache, some are not.""" with flask_app_with_containers.app_context(): # Arrange @@ -158,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == missing_plan["tenant-3"] - def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask): """Test fallback to API when Redis mget fails.""" with flask_app_with_containers.app_context(): # Arrange @@ -189,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert cached_1 is not None assert cached_2 is not None - def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache contains invalid JSON.""" with flask_app_with_containers.app_context(): # Arrange @@ -241,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == expected_plans["tenant-3"] - def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache data doesn't match SubscriptionPlan schema.""" with flask_app_with_containers.app_context(): # Arrange @@ -274,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was called for tenant-2 and tenant-3 mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) - def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask): """Test that pipeline failure doesn't affect return value.""" with flask_app_with_containers.app_context(): # Arrange @@ -303,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify pipeline was attempted mock_pipeline.assert_called_once() - def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask): """Test with empty tenant_ids list.""" with flask_app_with_containers.app_context(): # Act @@ -321,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache: # But we should check that mget was not called at all # Since we can't easily verify this without more mocking, we just verify the result - def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask): """Test that expired cache keys are treated as cache misses.""" with flask_app_with_containers.app_context(): # Arrange diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 98c38f2b5f..8aa10129c1 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory: class TestConversationServicePagination: """Test conversation pagination operations.""" - def test_pagination_with_non_empty_include_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session): """ Test that non-empty include_ids filters properly. @@ -204,7 +205,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[0].id, conversations[1].id} - def test_pagination_with_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that empty exclude_ids doesn't filter. @@ -237,7 +238,7 @@ class TestConversationServicePagination: # Assert assert len(result.data) == len(conversations) - def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that non-empty exclude_ids filters properly. @@ -271,7 +272,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[2].id} - def test_pagination_with_sorting_descending(self, db_session_with_containers): + def test_pagination_with_sorting_descending(self, db_session_with_containers: Session): """ Test pagination with descending sort order. @@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation: within conversations. """ - def test_pagination_by_first_id_without_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session): """ Test message pagination without specifying first_id. @@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 3 # All 3 messages returned assert result.has_more is False # No more messages available (3 < limit of 10) - def test_pagination_by_first_id_with_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session): """ Test message pagination with first_id specified. @@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 2 # Only 2 messages returned after first_id assert result.has_more is False # No more messages available (2 < limit of 10) - def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers): + def test_pagination_by_first_id_raises_error_when_first_message_not_found( + self, db_session_with_containers: Session + ): """ Test that FirstMessageNotExistsError is raised when first_id doesn't exist. @@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation: limit=10, ) - def test_pagination_with_has_more_flag(self, db_session_with_containers): + def test_pagination_with_has_more_flag(self, db_session_with_containers: Session): """ Test that has_more flag is correctly set when there are more messages. @@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set - def test_pagination_with_ascending_order(self, db_session_with_containers): + def test_pagination_with_ascending_order(self, db_session_with_containers: Session): """ Test message pagination with ascending order. @@ -512,7 +515,7 @@ class TestConversationServiceSummarization: """ @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session): """ Test successful auto-generation of conversation name. @@ -552,7 +555,7 @@ class TestConversationServiceSummarization: app_model.tenant_id, first_message.query, conversation.id, app_model.id ) - def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers): + def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session): """ Test that MessageNotExistsError is raised when conversation has no messages. @@ -571,7 +574,9 @@ class TestConversationServiceSummarization: ConversationService.auto_generate_name(app_model, conversation) @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_handles_llm_failure_gracefully( + self, mock_llm_generator, db_session_with_containers: Session + ): """ Test that LLM generation failures are suppressed and don't crash. @@ -604,7 +609,7 @@ class TestConversationServiceSummarization: assert conversation.name == original_name # Name remains unchanged @patch("services.conversation_service.naive_utc_now") - def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers): + def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session): """ Test renaming conversation with manual name. @@ -638,7 +643,7 @@ class TestConversationServiceSummarization: assert conversation.updated_at == mock_time @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers): + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session): """ Test rename delegates to auto_generate_name when auto_generate is True. @@ -682,7 +687,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_from_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating annotation from existing message. @@ -721,7 +728,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_without_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating standalone annotation without message. @@ -753,7 +762,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test updating an existing annotation. @@ -800,7 +809,7 @@ class TestConversationServiceMessageAnnotation: mock_add_task.delay.assert_not_called() @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving paginated annotation list. @@ -836,7 +845,7 @@ class TestConversationServiceMessageAnnotation: assert result_total == 5 @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving annotations with keyword filtering. @@ -885,7 +894,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test direct annotation insertion without message reference. @@ -919,7 +928,7 @@ class TestConversationServiceExport: Tests retrieving conversation data for export purposes. """ - def test_get_conversation_success(self, db_session_with_containers): + def test_get_conversation_success(self, db_session_with_containers: Session): """Test successful retrieval of conversation.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -937,7 +946,7 @@ class TestConversationServiceExport: # Assert assert result == conversation - def test_get_conversation_not_found(self, db_session_with_containers): + def test_get_conversation_not_found(self, db_session_with_containers: Session): """Test ConversationNotExistsError when conversation doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -949,7 +958,7 @@ class TestConversationServiceExport: ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user) @patch("services.annotation_service.current_account_with_tenant") - def test_export_annotation_list(self, mock_current_account, db_session_with_containers): + def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session): """Test exporting all annotations for an app.""" # Arrange app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -977,7 +986,7 @@ class TestConversationServiceExport: # Assert assert len(result) == 10 - def test_get_message_success(self, db_session_with_containers): + def test_get_message_success(self, db_session_with_containers: Session): """Test successful retrieval of a message.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -1001,7 +1010,7 @@ class TestConversationServiceExport: # Assert assert result == message - def test_get_message_not_found(self, db_session_with_containers): + def test_get_message_not_found(self, db_session_with_containers: Session): """Test MessageNotExistsError when message doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -1012,7 +1021,7 @@ class TestConversationServiceExport: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4())) - def test_get_conversation_for_end_user(self, db_session_with_containers): + def test_get_conversation_for_end_user(self, db_session_with_containers: Session): """ Test retrieving conversation created by end user via API. @@ -1038,7 +1047,7 @@ class TestConversationServiceExport: assert result == conversation @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session): """ Test conversation deletion with async cleanup. @@ -1071,7 +1080,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_called_once_with(conversation_id) @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session): """ Test deletion is denied when conversation belongs to a different account. """ @@ -1102,7 +1111,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_not_called() @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers): + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session): """ Test that delete propagates exceptions and does not trigger the cleanup task. diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py index 0b7bd9ca64..6c292dbc4b 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -5,7 +5,8 @@ from unittest.mock import patch from uuid import uuid4 import pytest -from sqlalchemy.orm import sessionmaker +from flask import Flask +from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db @@ -149,7 +150,7 @@ class ConversationServiceVariableIntegrationFactory: @pytest.fixture -def real_conversation_service_session_factory(flask_app_with_containers): +def real_conversation_service_session_factory(flask_app_with_containers: Flask): del flask_app_with_containers real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -162,7 +163,7 @@ def real_conversation_service_session_factory(flask_app_with_containers): class TestConversationServiceVariables: def test_get_conversational_variable_success( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -200,7 +201,7 @@ class TestConversationServiceVariables: assert result.has_more is False def test_get_conversational_variable_with_last_id( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -242,7 +243,7 @@ class TestConversationServiceVariables: assert result.has_more is False def test_get_conversational_variable_last_id_not_found_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -259,7 +260,7 @@ class TestConversationServiceVariables: ) def test_get_conversational_variable_sets_has_more( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -287,7 +288,7 @@ class TestConversationServiceVariables: assert result.has_more is True def test_update_conversation_variable_success( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -320,7 +321,7 @@ class TestConversationServiceVariables: assert result["updated_at"] == updated_at def test_update_conversation_variable_not_found_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -337,7 +338,7 @@ class TestConversationServiceVariables: ) def test_update_conversation_variable_type_mismatch_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -360,7 +361,7 @@ class TestConversationServiceVariables: ) def test_update_conversation_variable_integer_number_compatibility( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -390,7 +391,7 @@ class TestConversationServiceVariables: class TestConversationServicePaginationWithContainers: - def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers): + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) @@ -404,7 +405,7 @@ class TestConversationServicePaginationWithContainers: invoke_from=InvokeFrom.WEB_APP, ) - def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers): + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) base_time = datetime(2024, 1, 1, 8, 0, 0) @@ -442,7 +443,7 @@ class TestConversationServicePaginationWithContainers: assert newest.id != middle.id assert [conversation.id for conversation in result.data] == [oldest.id] - def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers): + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") @@ -462,7 +463,7 @@ class TestConversationServicePaginationWithContainers: assert alpha.id != beta.id assert [conversation.id for conversation in result.data] == [gamma.id] - def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers): + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) end_user = factory.create_end_user(db_session_with_containers, app) @@ -493,7 +494,7 @@ class TestConversationServicePaginationWithContainers: assert account_conversation.id != end_user_conversation.id assert [conversation.id for conversation in result.data] == [end_user_conversation.id] - def test_pagination_filters_to_account_console_source(self, db_session_with_containers): + def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) end_user = factory.create_end_user(db_session_with_containers, app) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 02ab3f8314..638a962f18 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pytest -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from extensions.ext_database import db from graphon.variables import StringVariable @@ -13,7 +13,12 @@ from services.conversation_variable_updater import ConversationVariableNotFoundE class TestConversationVariableUpdater: def _create_conversation_variable( - self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + self, + db_session_with_containers: Session, + *, + conversation_id: str, + variable: StringVariable, + app_id: str | None = None, ) -> ConversationVariable: row = ConversationVariable( id=variable.id, @@ -25,7 +30,7 @@ class TestConversationVariableUpdater: db_session_with_containers.commit() return row - def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="old value") self._create_conversation_variable( @@ -42,7 +47,7 @@ class TestConversationVariableUpdater: assert row is not None assert row.data == updated_variable.model_dump_json() - def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="value") updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) @@ -50,7 +55,7 @@ class TestConversationVariableUpdater: with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): updater.update(conversation_id=conversation_id, variable=variable) - def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session): updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) result = updater.flush() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 0f63d98642..09ba041244 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -3,6 +3,7 @@ from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.errors.error import QuotaExceededError from models import TenantCreditPool @@ -14,7 +15,7 @@ class TestCreditPoolService: def _create_tenant_id(self) -> str: return str(uuid4()) - def test_create_default_pool(self, db_session_with_containers): + def test_create_default_pool(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) @@ -25,7 +26,7 @@ class TestCreditPoolService: assert pool.quota_used == 0 assert pool.quota_limit > 0 - def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -35,17 +36,17 @@ class TestCreditPoolService: assert result.tenant_id == tenant_id assert result.pool_type == ProviderQuotaType.TRIAL - def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session): result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None - def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session): result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) assert result is False - def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -53,7 +54,7 @@ class TestCreditPoolService: assert result is True - def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) # Exhaust credits @@ -64,11 +65,11 @@ class TestCreditPoolService: assert result is False - def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session): with pytest.raises(QuotaExceededError, match="Credit pool not found"): CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) - def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) pool.quota_used = pool.quota_limit @@ -77,7 +78,7 @@ class TestCreditPoolService: with pytest.raises(QuotaExceededError, match="No credits remaining"): CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) - def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) credits_required = 10 @@ -89,7 +90,7 @@ class TestCreditPoolService: pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert pool.quota_used == credits_required - def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) remaining = 5 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 71c8874f79..f9898e2cfa 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db @@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory: class TestDatasetPermissionServiceGetPartialMemberList: """Verify partial-member list reads against persisted DatasetPermission rows.""" - def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session): """ Test retrieving partial member list with multiple members. """ @@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 3 - def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session): """ Test retrieving partial member list with single member. """ @@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 1 - def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session): """ Test retrieving partial member list when no members exist. """ @@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: class TestDatasetPermissionServiceUpdatePartialMemberList: """Verify partial-member list updates against persisted DatasetPermission rows.""" - def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session): """ Test adding new partial members to a dataset. """ @@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {member_1.id, member_2.id} - def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session): """ Test replacing existing partial members with new ones. """ @@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {new_member_1.id, new_member_2.id} - def test_update_partial_member_list_empty_list(self, db_session_with_containers): + def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test updating with empty member list (clearing all members). """ @@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: class TestDatasetPermissionServiceClearPartialMemberList: """Verify partial-member clearing against persisted DatasetPermission rows.""" - def test_clear_partial_member_list_success(self, db_session_with_containers): + def test_clear_partial_member_list_success(self, db_session_with_containers: Session): """ Test successful clearing of partial member list. """ @@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test clearing partial member list when no members exist. """ @@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" - def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session): """Test that users from different tenants cannot access dataset.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other_user) - def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session): """Test that tenant owners can access any dataset regardless of permission level.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, owner) - def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session): """Test ONLY_ME permission allows only the dataset creator to access.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, creator) - def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session): """Test ONLY_ME permission denies access to non-creators.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other) - def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session): """Test ALL_TEAM permission allows any team member to access the dataset.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, member) - def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_with_permission_success( + self, db_session_with_containers: Session + ): """ Test that user with explicit permission can access partial_members dataset. """ @@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission: permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert user.id in permissions - def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_without_permission_error( + self, db_session_with_containers: Session + ): """ Test error when user without permission tries to access partial_members dataset. """ @@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session): """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0de3c64c4f..e6ee896a52 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration: class TestDocumentServicePauseRecoverRetry: """Tests for pause/recover/retry orchestration using real DB and Redis.""" - def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"): factory = DatasetServiceIntegrationDataFactory account, tenant = factory.create_account_with_tenant(db_session_with_containers) dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) @@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry: db_session_with_containers.commit() return doc, account - def test_pause_document_success(self, db_session_with_containers): + def test_pause_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is not None redis_client.delete(cache_key) - def test_pause_document_invalid_status_error(self, db_session_with_containers): + def test_pause_document_invalid_status_error(self, db_session_with_containers: Session): from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry: with pytest.raises(DocumentIndexingError): DocumentService.pause_document(doc) - def test_recover_document_success(self, db_session_with_containers): + def test_recover_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is None recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) - def test_retry_document_indexing_success(self, db_session_with_containers): + def test_retry_document_indexing_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py index c486ff5613..08de79f4b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -6,6 +6,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin from services.dataset_service import DatasetService @@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset: permission="only_me", ) - def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session): tenant, _ = self._create_tenant_and_account(db_session_with_containers) mock_user = Mock(id=None) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 3cac964d89..c43a5d5978 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,8 @@ from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory: class TestDatasetServiceDeleteDataset: """Integration coverage for DatasetService.delete_dataset using testcontainers.""" - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset: dataset.pipeline_id, ) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """Return False without scheduling cleanup when the target dataset does not exist.""" # Arrange owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py index 1b4179c9c7..0603a1e27f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -6,6 +6,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -363,7 +364,7 @@ class TestDatasetServicePermissionsAndLifecycle: DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) - def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask): with flask_app_with_containers.app_context(): with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(str(uuid4()), True) @@ -473,7 +474,7 @@ class TestDatasetCollectionBindingServiceIntegration: assert persisted.type == "dataset" assert persisted.collection_name - def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask): with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index fe426ae516..69c39b8bfb 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -6,6 +6,7 @@ from datetime import UTC, datetime, timedelta from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.commit() return run - def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None: + def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None: archive_log = WorkflowArchiveLog( tenant_id=run.tenant_id, app_id=run.app_id, @@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.add(archive_log) db_session_with_containers.commit() - def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers): + def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session): deleter = ArchivedWorkflowRunDeletion() missing_run_id = str(uuid4()) @@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == f"Workflow run {missing_run_id} not found" - def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers): + def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session): tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, @@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == f"Workflow run {run.id} is not archived" - def test_delete_batch_uses_repo(self, db_session_with_containers): + def test_delete_batch_uses_repo(self, db_session_with_containers: Session): tenant_id = str(uuid4()) base_time = datetime.now(UTC) run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time) @@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion: ).all() assert remaining_runs == [] - def test_delete_run_calls_repo(self, db_session_with_containers): + def test_delete_run_calls_repo(self, db_session_with_containers: Session): tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, @@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion: deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None - def test_delete_run_dry_run(self, db_session_with_containers): + def test_delete_run_dry_run(self, db_session_with_containers: Session): """Dry run should return success without actually deleting.""" tenant_id = str(uuid4()) run = self._create_workflow_run( @@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(WorkflowRun, run_id) is not None - def test_delete_run_exception_returns_error(self, db_session_with_containers): + def test_delete_run_exception_returns_error(self, db_session_with_containers: Session): """Exception during deletion should return failure result.""" from unittest.mock import MagicMock, patch @@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == "Database error" - def test_delete_by_run_id_success(self, db_session_with_containers): + def test_delete_by_run_id_success(self, db_session_with_containers: Session): """Successfully delete an archived workflow run by ID.""" tenant_id = str(uuid4()) base_time = datetime.now(UTC) @@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() assert db_session_with_containers.get(WorkflowRun, run_id) is None - def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session): """_get_workflow_run_repo should return a cached repo on subsequent calls.""" deleter = ArchivedWorkflowRunDeletion() diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index cafabc939b..074d448aab 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory): + def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory): """Test getting or creating end user with custom user_id.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser: assert result.type == InvokeFrom.SERVICE_API assert result.is_anonymous is False - def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory): + def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory): """Test getting or creating end user without user_id uses default session.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser: # Verify _is_anonymous is set correctly (property always returns False) assert result._is_anonymous is True - def test_get_existing_end_user(self, db_session_with_containers, factory): + def test_get_existing_end_user(self, db_session_with_containers: Session, factory): """Test retrieving an existing end user.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_create_end_user_service_api_type(self, db_session_with_containers, factory): + def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory): """Test creating new end user with SERVICE_API type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.app_id == app_id assert result.session_id == user_id - def test_create_end_user_web_app_type(self, db_session_with_containers, factory): + def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory): """Test creating new end user with WEB_APP type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.type == InvokeFrom.WEB_APP @patch("services.end_user_service.logger") - def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory): + def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory): """Test upgrading legacy end user with different type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert "Upgrading legacy EndUser" in log_call @patch("services.end_user_service.logger") - def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory): + def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory): """Test retrieving existing end user with matching type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.type == InvokeFrom.SERVICE_API mock_logger.info.assert_not_called() - def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory): + def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory): """Test creating anonymous user when user_id is None.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result._is_anonymous is True assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory): + def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory): """Test that query ordering prioritizes records with matching type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.id == matching.id assert result.id != non_matching.id - def test_external_user_id_matches_session_id(self, db_session_with_containers, factory): + def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory): """Test that external_user_id is set to match session_id.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType: InvokeFrom.DEBUGGER, ], ) - def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory): + def test_create_end_user_with_different_invoke_types( + self, db_session_with_containers: Session, invoke_type, factory + ): """Test creating end users with different InvokeFrom types.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory): + def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory): app = factory.create_app_and_account(db_session_with_containers) existing_user = factory.create_end_user( db_session_with_containers, @@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById: assert result is not None assert result.id == existing_user.id - def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory): + def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory): app = factory.create_app_and_account(db_session_with_containers) result = EndUserService.get_end_user_by_id( @@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch: def factory(self): return TestEndUserServiceFactory() - def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3): """Create multiple apps under the same tenant.""" first_app = factory.create_app_and_account(db_session_with_containers) tenant_id = first_app.tenant_id @@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch: all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() return tenant_id, all_apps - def test_create_batch_empty_app_ids(self, db_session_with_containers): + def test_create_batch_empty_app_ids(self, db_session_with_containers: Session): result = EndUserService.create_end_user_batch( type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" ) assert result == {} - def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) app_ids = [a.id for a in apps] user_id = f"user-{uuid4()}" @@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch: assert result[app_id].session_id == user_id assert result[app_id].type == InvokeFrom.SERVICE_API - def test_create_batch_default_session_id(self, db_session_with_containers, factory): + def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [a.id for a in apps] @@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch: assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID assert end_user._is_anonymous is True - def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] user_id = f"user-{uuid4()}" @@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch: assert len(result) == 2 - def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [a.id for a in apps] user_id = f"user-{uuid4()}" @@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch: for app_id in app_ids: assert first_result[app_id].id == second_result[app_id].id - def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) user_id = f"user-{uuid4()}" @@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch: "invoke_type", [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], ) - def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) user_id = f"user-{uuid4()}" diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 315936d721..f78aeaf984 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from services.feature_service import ( @@ -81,7 +82,7 @@ class TestFeatureService: fake = Faker() return fake.uuid4() - def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful feature retrieval with billing and enterprise enabled. @@ -156,7 +157,7 @@ class TestFeatureService: tenant_id ) - def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test feature retrieval for sandbox plan with specific limitations. @@ -222,7 +223,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) - def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_knowledge_rate_limit_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful knowledge rate limit retrieval with billing enabled. @@ -255,7 +258,7 @@ class TestFeatureService: tenant_id ) - def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system features retrieval with enterprise and marketplace enabled. @@ -332,7 +335,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_unauthenticated( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval for an unauthenticated user. @@ -386,7 +391,9 @@ class TestFeatureService: # Marketplace should be visible assert result.enable_marketplace is True - def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_basic_config( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with basic configuration (no enterprise). @@ -436,7 +443,9 @@ class TestFeatureService: # Verify plugin package size (uses default value from dify_config) assert result.max_plugin_package_size == 15728640 - def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_billing_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval when billing is disabled. @@ -492,7 +501,7 @@ class TestFeatureService: assert result.webapp_copyright_enabled is False def test_get_knowledge_rate_limit_billing_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test knowledge rate limit retrieval when billing is disabled. @@ -523,7 +532,9 @@ class TestFeatureService: # Verify no billing service calls mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called() - def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_enterprise_only( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with enterprise enabled but billing disabled. @@ -583,7 +594,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_not_called() def test_get_system_features_enterprise_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval when enterprise is disabled. @@ -640,7 +651,7 @@ class TestFeatureService: # Verify no enterprise service calls mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called() - def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test feature retrieval without tenant ID (billing disabled). @@ -686,7 +697,9 @@ class TestFeatureService: # Verify no billing service calls mock_external_service_dependencies["billing_service"].get_info.assert_not_called() - def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_partial_billing_info( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with partial billing information. @@ -746,7 +759,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) - def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_vector_space( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case vector space configuration. @@ -807,7 +822,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_webapp_auth( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case webapp auth configuration. @@ -863,7 +878,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_members_quota( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case members quota configuration. @@ -924,7 +941,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_plugin_installation_permission_scopes( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with different plugin installation permission scopes. @@ -1023,7 +1040,7 @@ class TestFeatureService: assert result.plugin_installation_permission.restrict_to_marketplace_only is True def test_get_features_workspace_members_missing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval when workspace members info is missing from enterprise. @@ -1064,7 +1081,9 @@ class TestFeatureService: tenant_id ) - def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_license_inactive( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with inactive license. @@ -1117,7 +1136,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_system_features_partial_enterprise_info( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with partial enterprise information. @@ -1186,7 +1205,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_limits( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case limit values. @@ -1244,7 +1265,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_protocols( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case protocol values. @@ -1297,7 +1318,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_education( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case education configuration. @@ -1353,7 +1376,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_license_limitation_model_is_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test LicenseLimitationModel.is_available method with various scenarios. @@ -1394,7 +1417,7 @@ class TestFeatureService: assert exact_limit.is_available(3) is True def test_get_features_workspace_members_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval when workspace members are disabled in enterprise. @@ -1433,7 +1456,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id) - def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_license_expired( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with expired license. @@ -1486,7 +1511,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_docs_processing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case document processing configuration. @@ -1544,7 +1569,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_branding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case branding configuration. @@ -1606,7 +1631,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_annotation_quota( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case annotation quota configuration. @@ -1668,7 +1693,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_features_edge_case_documents_upload( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case documents upload settings. @@ -1733,7 +1758,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_license_lost( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features with lost license status. @@ -1784,7 +1809,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_education_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with education feature disabled. diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index ed75363f3b..ce63e7a71a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -6,6 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session from configs import dify_config from core.workflow.human_input_adapter import ( @@ -88,7 +89,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, flask_app_with_containers, db_session_with_containers): + def test_default(self, flask_app_with_containers, db_session_with_containers: Session): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account = Account(name="Test User", email="member@example.com") db_session_with_containers.add(account) @@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index b55a19eaa9..fffa82bf5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from models.dataset import Dataset, DatasetMetadataBinding, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -65,7 +66,7 @@ class TestMetadataPartialUpdate: yield account def test_partial_update_merges_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -92,7 +93,7 @@ class TestMetadataPartialUpdate: assert updated_doc.doc_metadata["new_key"] == "new_value" def test_full_update_replaces_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -119,7 +120,7 @@ class TestMetadataPartialUpdate: assert "existing_key" not in updated_doc.doc_metadata def test_partial_update_skips_existing_binding( - self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -159,7 +160,7 @@ class TestMetadataPartialUpdate: assert len(bindings) == 1 def test_rollback_called_on_commit_failure( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py index c146a5924b..5fa5de6d80 100644 --- a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from models.model import OAuthProviderApp @@ -25,7 +26,7 @@ from services.oauth_server import ( class TestOAuthServerServiceGetProviderApp: """DB-backed tests for get_oauth_provider_app.""" - def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp: app = OAuthProviderApp( app_icon="icon.png", client_id=client_id, @@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp: db_session_with_containers.commit() return app - def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session): client_id = f"client-{uuid4()}" created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) @@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp: assert result.client_id == client_id assert result.id == created.id - def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session): result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index 7036524918..2f20949611 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -8,6 +8,7 @@ from datetime import datetime from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore @@ -39,7 +40,7 @@ class TestWorkflowRunRestore: assert result["created_at"].month == 1 assert result["name"] == "test" - def test_restore_table_records_returns_rowcount(self, db_session_with_containers): + def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() record_id = str(uuid4()) @@ -65,7 +66,7 @@ class TestWorkflowRunRestore: restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id)) assert restored_pause is not None - def test_restore_table_records_unknown_table(self, db_session_with_containers): + def test_restore_table_records_unknown_table(self, db_session_with_containers: Session): """Unknown table names should be ignored gracefully.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 970da98c55..6d5c7380b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from flask import Flask +from sqlalchemy.orm import Session from werkzeug.datastructures import FileStorage from models.enums import AppTriggerStatus, AppTriggerType @@ -52,7 +53,7 @@ class TestWebhookService: } @pytest.fixture - def test_data(self, db_session_with_containers, mock_external_dependencies): + def test_data(self, db_session_with_containers: Session, mock_external_dependencies): """Create test data for webhook service tests.""" fake = Faker() @@ -160,7 +161,7 @@ class TestWebhookService: "app_trigger": app_trigger, } - def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers: Flask): """Test successful retrieval of webhook trigger and workflow.""" webhook_id = test_data["webhook_id"] @@ -175,7 +176,7 @@ class TestWebhookService: assert node_config["id"] == "webhook_node" assert node_config["data"].title == "Test Webhook" - def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers: Flask): """Test webhook trigger not found scenario.""" with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Webhook not found"): @@ -421,7 +422,9 @@ class TestWebhookService: assert result["files"] == {} - def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers): + def test_trigger_workflow_execution_success( + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask + ): """Test successful workflow execution trigger.""" webhook_data = { "method": "POST", @@ -452,7 +455,7 @@ class TestWebhookService: mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once() def test_trigger_workflow_execution_end_user_service_failure( - self, test_data, mock_external_dependencies, flask_app_with_containers + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask ): """Test workflow execution trigger when EndUserService fails.""" webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py index 85ce3a6ba6..69cde847f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy import select from sqlalchemy.orm import Session @@ -165,7 +166,7 @@ class WebhookServiceRelationshipFactory: class TestWebhookServiceLookupWithContainers: def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -182,7 +183,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -202,7 +203,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -222,7 +223,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -239,7 +240,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -275,7 +276,7 @@ class TestWebhookServiceLookupWithContainers: class TestWebhookServiceTriggerExecutionWithContainers: def test_trigger_workflow_execution_triggers_async_workflow_successfully( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -318,7 +319,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: assert trigger_args[2].root_node_id == webhook_trigger.node_id def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -354,7 +355,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: mock_mark_rate_limited.assert_called_once_with(tenant.id) def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -386,7 +387,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: class TestWebhookServiceRelationshipSyncWithContainers: def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -401,7 +402,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: WebhookService.sync_webhook_relationships(app, workflow) def test_sync_webhook_relationships_raises_when_lock_not_acquired( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -418,7 +419,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: WebhookService.sync_webhook_relationships(app, workflow) def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -455,7 +456,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None def test_sync_webhook_relationships_sets_redis_cache_for_new_record( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -481,7 +482,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert cached_payload["webhook_id"] == "cache-webhook-id-00001" def test_sync_webhook_relationships_logs_when_lock_release_fails( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 1e57b5603d..a2cdddad61 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1530,7 +1530,7 @@ class TestWorkflowAppService: assert result_cross_tenant["total"] == 0 def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() @@ -1543,7 +1543,7 @@ class TestWorkflowAppService: ) def test_get_paginate_workflow_app_logs_filters_by_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() @@ -1558,7 +1558,9 @@ class TestWorkflowAppService: assert result["total"] >= 0 assert isinstance(result["data"], list) - def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_workflow_archive_logs( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 86cf2327c7..82fe391b08 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -45,7 +45,9 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, fake: Faker | None = None + ): """ Helper method to create a test app with realistic data for testing. @@ -80,7 +82,7 @@ class TestWorkflowDraftVariableService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index b5ce8a53de..9ba1fda08b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -12,7 +12,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models import Account, App, Workflow +from models import Account, AccountStatus, App, TenantStatus, Workflow from models.model import AppMode from models.workflow import WorkflowType from services.workflow_service import WorkflowService @@ -33,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -49,7 +49,7 @@ class TestWorkflowService: email=fake.email(), name=fake.name(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", # Set interface language for Site creation ) account.created_at = fake.date_time_this_year() @@ -62,7 +62,7 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="normal", + status=TenantStatus.NORMAL, ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() @@ -77,7 +77,7 @@ class TestWorkflowService: return account - def _create_test_app(self, db_session_with_containers: Session, fake=None): + def _create_test_app(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test app with realistic data. @@ -109,7 +109,7 @@ class TestWorkflowService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py index 29e1e240b4..afc4908c15 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -100,7 +100,7 @@ class TestWorkflowDeletion: session.flush() return provider - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -118,7 +118,7 @@ class TestWorkflowDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(Workflow, workflow_id) is None - def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + def test_delete_draft_workflow_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -130,7 +130,7 @@ class TestWorkflowDeletion: with pytest.raises(DraftWorkflowDeletionError): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -144,7 +144,7 @@ class TestWorkflowDeletion: with pytest.raises(WorkflowInUseError, match="currently in use by app"): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 4dab895135..32b76c3469 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -64,7 +64,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: db_session_with_containers.commit() return execution - def test_get_node_last_execution_found(self, db_session_with_containers): + def test_get_node_last_execution_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it exists.""" # Arrange tenant_id = str(uuid4()) @@ -110,7 +110,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result.id == expected.id assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - def test_get_node_last_execution_not_found(self, db_session_with_containers): + def test_get_node_last_execution_not_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it doesn't exist.""" # Arrange tenant_id = str(uuid4()) @@ -129,7 +129,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_get_executions_by_workflow_run_empty(self, db_session_with_containers): + def test_get_executions_by_workflow_run_empty(self, db_session_with_containers: Session): """Test getting executions for a workflow run when none exist.""" # Arrange tenant_id = str(uuid4()) @@ -147,7 +147,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == [] - def test_get_execution_by_id_found(self, db_session_with_containers): + def test_get_execution_by_id_found(self, db_session_with_containers: Session): """Test getting execution by ID when it exists.""" # Arrange execution = self._create_execution( @@ -170,7 +170,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is not None assert result.id == execution.id - def test_get_execution_by_id_not_found(self, db_session_with_containers): + def test_get_execution_by_id_not_found(self, db_session_with_containers: Session): """Test getting execution by ID when it doesn't exist.""" # Arrange repository = self._create_repository(db_session_with_containers) @@ -182,7 +182,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_delete_expired_executions(self, db_session_with_containers): + def test_delete_expired_executions(self, db_session_with_containers: Session): """Test deleting expired executions.""" # Arrange tenant_id = str(uuid4()) @@ -248,7 +248,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_2_id not in remaining_ids assert kept_execution_id in remaining_ids - def test_delete_executions_by_app(self, db_session_with_containers): + def test_delete_executions_by_app(self, db_session_with_containers: Session): """Test deleting executions by app.""" # Arrange tenant_id = str(uuid4()) @@ -313,7 +313,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert deleted_2_id not in remaining_ids assert kept_id in remaining_ids - def test_get_expired_executions_batch(self, db_session_with_containers): + def test_get_expired_executions_batch(self, db_session_with_containers: Session): """Test getting expired executions batch for backup.""" # Arrange tenant_id = str(uuid4()) @@ -370,7 +370,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_1.id in result_ids assert old_execution_2.id in result_ids - def test_delete_executions_by_ids(self, db_session_with_containers): + def test_delete_executions_by_ids(self, db_session_with_containers: Session): """Test deleting executions by IDs.""" # Arrange tenant_id = str(uuid4()) @@ -424,7 +424,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: ).all() assert remaining == [] - def test_delete_executions_by_ids_empty_list(self, db_session_with_containers): + def test_delete_executions_by_ids_empty_list(self, db_session_with_containers: Session): """Test deleting executions with empty ID list.""" # Arrange repository = self._create_repository(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 7e5c374b5d..1c8d5969e0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -71,7 +71,7 @@ class TestCleanNotionDocumentTask: yield mock_factory def test_clean_notion_document_task_success( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test successful cleanup of Notion documents with proper database operations. @@ -176,7 +176,7 @@ class TestCleanNotionDocumentTask: # 5. The task completes without errors def test_clean_notion_document_task_dataset_not_found( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior when dataset is not found. @@ -196,7 +196,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.return_value.init_index_processor.assert_not_called() def test_clean_notion_document_task_empty_document_list( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior with empty document list. @@ -240,7 +240,7 @@ class TestCleanNotionDocumentTask: assert args[1] == [] def test_clean_notion_document_task_with_different_index_types( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with different dataset index types. @@ -328,7 +328,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.reset_mock() def test_clean_notion_document_task_with_segments_no_index_node_ids( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments that have no index_node_ids. @@ -411,7 +411,7 @@ class TestCleanNotionDocumentTask: # are properly deleted from the database. def test_clean_notion_document_task_partial_document_cleanup( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with partial document cleanup scenario. @@ -513,7 +513,7 @@ class TestCleanNotionDocumentTask: # The database operations work correctly, isolating only the specified documents. def test_clean_notion_document_task_with_mixed_segment_statuses( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments in different statuses. @@ -603,7 +603,7 @@ class TestCleanNotionDocumentTask: # IndexProcessor verification would require more sophisticated mocking. def test_clean_notion_document_task_continues_when_index_processor_fails( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Index processor failure (e.g. transient billing API error propagated via @@ -707,7 +707,7 @@ class TestCleanNotionDocumentTask: assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 def test_clean_notion_document_task_with_large_number_of_documents( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with a large number of documents and segments. @@ -806,7 +806,7 @@ class TestCleanNotionDocumentTask: # The database efficiently handles large-scale deletions. def test_clean_notion_document_task_with_documents_from_different_tenants( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents from different tenants. @@ -918,7 +918,7 @@ class TestCleanNotionDocumentTask: # Only documents from the target dataset are affected, maintaining tenant separation. def test_clean_notion_document_task_with_documents_in_different_states( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents in different indexing states. @@ -1024,7 +1024,7 @@ class TestCleanNotionDocumentTask: # All documents are deleted regardless of their indexing status. def test_clean_notion_document_task_with_documents_having_metadata( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents that have rich metadata. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 9084667c31..80289c448a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker from sqlalchemy import delete +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client @@ -25,7 +26,7 @@ class TestCreateSegmentToIndexTask: """Integration tests for create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database and Redis before each test to ensure isolation.""" # Clear all test data using fixture session @@ -55,7 +56,7 @@ class TestCreateSegmentToIndexTask: "index_processor": mock_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -102,7 +103,7 @@ class TestCreateSegmentToIndexTask: return account, tenant - def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + def _create_test_dataset_and_document(self, db_session_with_containers: Session, tenant_id, account_id): """ Helper method to create a test dataset and document for testing. @@ -151,7 +152,13 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING + self, + db_session_with_containers: Session, + dataset_id, + document_id, + tenant_id, + account_id, + status=SegmentStatus.WAITING, ): """ Helper method to create a test document segment for testing. @@ -189,7 +196,9 @@ class TestCreateSegmentToIndexTask: return segment - def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful creation of segment to index. @@ -225,7 +234,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_segment_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent segment ID. @@ -246,7 +255,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_invalid_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with invalid status. @@ -277,7 +286,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated dataset. @@ -330,7 +341,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated document. @@ -367,7 +380,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with disabled document. @@ -403,7 +416,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_archived( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with archived document. @@ -439,7 +452,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_indexing_incomplete( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with document that has incomplete indexing. @@ -475,7 +488,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_processor_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of index processor exceptions. @@ -511,7 +524,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_with_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with custom keywords. @@ -543,7 +556,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with different document forms. @@ -586,7 +599,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) def test_create_segment_to_index_performance_timing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing performance and timing. @@ -617,7 +630,7 @@ class TestCreateSegmentToIndexTask: assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test concurrent execution of segment indexing tasks. @@ -654,7 +667,7 @@ class TestCreateSegmentToIndexTask: assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 def test_create_segment_to_index_large_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with large content. @@ -703,7 +716,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_redis_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing when Redis operations fail. @@ -743,7 +756,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 1 def test_create_segment_to_index_database_transaction_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with database transaction handling. @@ -775,7 +788,7 @@ class TestCreateSegmentToIndexTask: assert segment.error is not None def test_create_segment_to_index_metadata_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with metadata validation. @@ -817,7 +830,7 @@ class TestCreateSegmentToIndexTask: assert doc is not None def test_create_segment_to_index_status_transition_flow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test complete status transition flow during indexing. @@ -852,7 +865,7 @@ class TestCreateSegmentToIndexTask: assert segment.indexing_at <= segment.completed_at def test_create_segment_to_index_with_empty_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with empty or minimal content. @@ -894,7 +907,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with special characters and unicode content. @@ -940,7 +953,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_long_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with long keyword lists. @@ -974,7 +987,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with proper tenant isolation. @@ -1017,7 +1030,7 @@ class TestCreateSegmentToIndexTask: assert segment1.tenant_id != segment2.tenant_id def test_create_segment_to_index_with_none_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with None keywords parameter. @@ -1048,7 +1061,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_comprehensive_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Comprehensive integration test covering multiple scenarios. diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 684097851b..a5a3cd10b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -174,11 +175,11 @@ class TestDatasetIndexingTaskIntegration: return dataset, documents - def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: + def _query_document(self, db_session_with_containers: Session, document_id: str) -> Document | None: """Return the latest persisted document state.""" return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) - def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: + def _assert_documents_parsing(self, db_session_with_containers: Session, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" db_session_with_containers.expire_all() for document_id in document_ids: @@ -212,7 +213,9 @@ class TestDatasetIndexingTaskIntegration: assert len(opened) >= 2 assert opened_ids <= closed_ids - def test_legacy_document_indexing_task_still_works(self, db_session_with_containers, patched_external_dependencies): + def test_legacy_document_indexing_task_still_works( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Ensure the legacy task entrypoint still updates parsing status.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -225,7 +228,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_multiple_documents(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_multiple_documents( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Process multiple documents in one batch.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -240,7 +245,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == len(document_ids) self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_with_limit_check(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_with_limit_check( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Reject batches larger than configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. @@ -263,7 +270,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_error_contains(db_session_with_containers, document_ids, "batch upload limit") def test_batch_processing_sandbox_plan_single_document_only( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Reject multi-document upload under sandbox plan.""" # Arrange @@ -280,7 +287,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_not_called() self._assert_documents_error_contains(db_session_with_containers, document_ids, "does not support batch upload") - def test_batch_processing_empty_document_list(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_empty_document_list( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Handle empty list input without failing.""" # Arrange dataset, _ = self._create_test_dataset_and_documents(db_session_with_containers, document_count=0) @@ -292,7 +301,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_tenant_queue_dispatches_next_task_after_completion( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Dispatch the next queued task after current tenant task completes. @@ -337,7 +346,7 @@ class TestDatasetIndexingTaskIntegration: delete_key_spy.assert_not_called() def test_tenant_queue_deletes_running_key_when_no_follow_up_tasks( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Delete tenant running flag when queue has no pending tasks. @@ -362,7 +371,7 @@ class TestDatasetIndexingTaskIntegration: delete_key_spy.assert_called_once() def test_validation_failure_sets_error_status_when_vector_space_at_limit( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Set error status when vector space validation fails before runner phase.""" # Arrange @@ -382,7 +391,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_error_contains(db_session_with_containers, document_ids, "over the limit") def test_runner_exception_does_not_crash_indexing_task( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Catch generic runner exceptions without crashing the task.""" # Arrange @@ -397,7 +406,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_document_paused_error_handling(self, db_session_with_containers, patched_external_dependencies): + def test_document_paused_error_handling(self, db_session_with_containers: Session, patched_external_dependencies): """Handle DocumentIsPausedError and keep persisted state consistent.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -424,7 +433,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_not_called() def test_tenant_queue_error_handling_still_processes_next_task( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Even on current task failure, enqueue the next waiting tenant task. @@ -491,7 +500,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_all_opened_sessions_closed(session_close_tracker) def test_multiple_documents_with_mixed_success_and_failure( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Process only existing documents when request includes missing ids.""" # Arrange @@ -508,7 +517,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, existing_ids) def test_tenant_queue_dispatches_up_to_concurrency_limit( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Dispatch only up to configured concurrency under queued backlog burst. @@ -543,7 +552,7 @@ class TestDatasetIndexingTaskIntegration: assert task_dispatch_spy.apply_async.call_count == concurrency_limit assert set_waiting_spy.call_count == concurrency_limit - def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies): + def test_task_queue_fifo_ordering(self, db_session_with_containers: Session, patched_external_dependencies): """Keep FIFO ordering when dispatching next queued tasks. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -576,7 +585,9 @@ class TestDatasetIndexingTaskIntegration: call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {}) assert call_kwargs.get("document_ids") == expected_task["document_ids"] - def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies): + def test_billing_disabled_skips_limit_checks( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Skip limit checks when billing feature is disabled.""" # Arrange large_document_ids = [str(uuid.uuid4()) for _ in range(100)] @@ -595,7 +606,7 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == 100 self._assert_documents_parsing(db_session_with_containers, large_document_ids) - def test_complete_workflow_normal_task(self, db_session_with_containers, patched_external_dependencies): + def test_complete_workflow_normal_task(self, db_session_with_containers: Session, patched_external_dependencies): """Run end-to-end normal queue workflow with tenant queue cleanup. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -618,7 +629,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) delete_key_spy.assert_called_once() - def test_complete_workflow_priority_task(self, db_session_with_containers, patched_external_dependencies): + def test_complete_workflow_priority_task(self, db_session_with_containers: Session, patched_external_dependencies): """Run end-to-end priority queue workflow with tenant queue cleanup. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -641,7 +652,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) delete_key_spy.assert_called_once() - def test_single_document_processing(self, db_session_with_containers, patched_external_dependencies): + def test_single_document_processing(self, db_session_with_containers: Session, patched_external_dependencies): """Process the minimum batch size (single document).""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=1) @@ -655,7 +666,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == 1 self._assert_documents_parsing(db_session_with_containers, [document_id]) - def test_document_with_special_characters_in_id(self, db_session_with_containers, patched_external_dependencies): + def test_document_with_special_characters_in_id( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Handle standard UUID ids with hyphen characters safely.""" # Arrange special_document_id = str(uuid.uuid4()) @@ -670,7 +683,9 @@ class TestDatasetIndexingTaskIntegration: # Assert self._assert_documents_parsing(db_session_with_containers, [special_document_id]) - def test_zero_vector_space_limit_allows_unlimited(self, db_session_with_containers, patched_external_dependencies): + def test_zero_vector_space_limit_allows_unlimited( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Treat vector limit 0 as unlimited and continue indexing.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -689,7 +704,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) def test_negative_vector_space_values_handled_gracefully( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Treat negative vector limits as non-blocking and continue indexing.""" # Arrange @@ -708,7 +723,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_large_document_batch_processing(self, db_session_with_containers, patched_external_dependencies): + def test_large_document_batch_processing(self, db_session_with_containers: Session, patched_external_dependencies): """Process a batch exactly at configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 48fec441c5..e4cbb9e589 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -55,7 +56,7 @@ class TestDealDatasetVectorIndexTask: yield mock_factory @pytest.fixture - def account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """Create an account with an owner tenant for testing. Returns a tuple of (account, tenant) where tenant is guaranteed to be non-None. @@ -73,7 +74,7 @@ class TestDealDatasetVectorIndexTask: return account, tenant def test_deal_dataset_vector_index_task_remove_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful removal of dataset vector index. @@ -131,7 +132,7 @@ class TestDealDatasetVectorIndexTask: assert mock_processor.clean.call_count >= 0 # For now, just check it doesn't fail def test_deal_dataset_vector_index_task_add_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful addition of dataset vector index. @@ -233,7 +234,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_update_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful update of dataset vector index. @@ -337,7 +338,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_dataset_not_found_error( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior when dataset is not found. @@ -357,7 +358,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_no_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action when no documents exist for the dataset. @@ -389,7 +390,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_no_segments( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action when documents exist but have no segments. @@ -447,7 +448,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_update_action_no_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test update action when no documents exist for the dataset. @@ -480,7 +481,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_with_exception_handling( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action with exception handling during processing. @@ -578,7 +579,7 @@ class TestDealDatasetVectorIndexTask: assert "Test exception during indexing" in updated_document.error def test_deal_dataset_vector_index_task_with_custom_index_type( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with custom index type (QA_INDEX). @@ -656,7 +657,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_default_index_type( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with default index type (PARAGRAPH_INDEX). @@ -734,7 +735,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_multiple_documents_processing( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task processing with multiple documents and segments. @@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask: assert mock_processor.load.call_count == 3 def test_deal_dataset_vector_index_task_document_status_transitions( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test document status transitions during task execution. @@ -938,7 +939,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with disabled documents. @@ -1061,7 +1062,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_archived_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with archived documents. @@ -1184,7 +1185,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_incomplete_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with documents that have incomplete indexing status. diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 8a69707b38..f4a71040c1 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -11,9 +11,19 @@ import logging from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Dataset, Document, DocumentSegment, Tenant +from models import ( + Account, + AccountStatus, + Dataset, + DatasetPermissionEnum, + Document, + DocumentSegment, + Tenant, + TenantStatus, +) from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -37,7 +47,7 @@ class TestDeleteSegmentFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_tenant(self, db_session_with_containers, fake=None): + def _create_test_tenant(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant with realistic data. @@ -49,7 +59,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status=TenantStatus.NORMAL) tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -58,7 +68,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return tenant - def _create_test_account(self, db_session_with_containers, tenant, fake=None): + def _create_test_account(self, db_session_with_containers: Session, tenant, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -75,7 +85,7 @@ class TestDeleteSegmentFromIndexTask: name=fake.name(), email=fake.email(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", ) account.id = fake.uuid4() @@ -86,7 +96,9 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return account - def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None): + def _create_test_dataset( + self, db_session_with_containers: Session, tenant: Tenant, account: Account, fake: Faker | None = None + ): """ Helper method to create a test dataset with realistic data. @@ -106,7 +118,7 @@ class TestDeleteSegmentFromIndexTask: dataset.name = f"Test Dataset {fake.word()}" dataset.description = fake.text(max_nb_chars=200) dataset.provider = "vendor" - dataset.permission = "only_me" + dataset.permission = DatasetPermissionEnum.ONLY_ME dataset.data_source_type = DataSourceType.UPLOAD_FILE dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' @@ -122,7 +134,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None, **kwargs): """ Helper method to create a test document with realistic data. @@ -172,7 +184,14 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return document - def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None): + def _create_test_document_segments( + self, + db_session_with_containers: Session, + document: Document, + account: Account, + count: int = 3, + fake: Faker | None = None, + ): """ Helper method to create test document segments with realistic data. @@ -218,7 +237,9 @@ class TestDeleteSegmentFromIndexTask: return segments @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) - def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers): + def test_delete_segment_from_index_task_success( + self, mock_index_processor_factory, db_session_with_containers: Session + ): """ Test successful segment deletion from index with comprehensive verification. @@ -267,7 +288,7 @@ class TestDeleteSegmentFromIndexTask: assert call_args[1]["with_keywords"] is True assert call_args[1]["delete_child_chunks"] is True - def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers): + def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers: Session): """ Test task behavior when dataset is not found. @@ -288,7 +309,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found - def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers: Session): """ Test task behavior when document is not found. @@ -314,7 +335,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document not found - def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers: Session): """ Test task behavior when document is disabled. @@ -342,7 +363,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled - def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers: Session): """ Test task behavior when document is archived. @@ -370,7 +391,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document is archived - def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers: Session): """ Test task behavior when document indexing is not completed. diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6e03bd9351..6bfb1e1f1e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -13,7 +13,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Dataset, DocumentSegment +from models import Account, AccountStatus, Dataset, DocumentSegment, TenantAccountRole, TenantStatus from models import Document as DatasetDocument from models.dataset import DatasetProcessRule from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus @@ -35,7 +35,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -51,24 +51,23 @@ class TestDisableSegmentsFromIndexTask: email=fake.email(), name=fake.name(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", ) - account.id = fake.uuid4() # monkey-patch attributes for test setup + account.updated_at = fake.date_time_this_year() + account.created_at = fake.date_time_this_year() + account.role = TenantAccountRole.OWNER + account.id = fake.uuid4() account.tenant_id = fake.uuid4() account.type = "normal" - account.role = "owner" - account.created_at = fake.date_time_this_year() - account.updated_at = account.created_at - # Create a tenant for the account from models.account import Tenant tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="normal", + status=TenantStatus.NORMAL, ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() @@ -83,7 +82,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None): """ Helper method to create a test dataset with realistic data. @@ -117,7 +116,9 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): + def _create_test_document( + self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None + ): """ Helper method to create a test document with realistic data. @@ -216,7 +217,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None): """ Helper method to create a dataset process rule. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index b6e7e6e5c9..77cd259833 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy import delete, func, select, update +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -162,7 +163,7 @@ class TestDocumentIndexingSyncTask: "indexing_runner": indexing_runner, } - def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None): + def _create_notion_sync_context(self, db_session_with_containers: Session, *, data_source_info: dict | None = None): account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset( db_session_with_containers, @@ -206,7 +207,7 @@ class TestDocumentIndexingSyncTask: "notion_info": notion_info, } - def test_document_not_found(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task handles missing document gracefully.""" # Arrange dataset_id = str(uuid4()) @@ -219,7 +220,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies): + def test_missing_notion_workspace_id(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when notion_workspace_id is missing.""" # Arrange context = self._create_notion_sync_context( @@ -235,7 +236,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies): + def test_missing_notion_page_id(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when notion_page_id is missing.""" # Arrange context = self._create_notion_sync_context( @@ -251,7 +252,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies): + def test_empty_data_source_info(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) @@ -264,7 +265,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies): + def test_credential_not_found(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task sets document error state when credential is missing.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -284,7 +285,7 @@ class TestDocumentIndexingSyncTask: assert updated_document.stopped_at is not None mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies): + def test_page_not_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task exits early when notion page is unchanged.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -310,7 +311,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies): + def test_successful_sync_when_page_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test full successful sync flow with SQL state updates and side effects.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -349,7 +350,7 @@ class TestDocumentIndexingSyncTask: assert len(run_documents) == 1 assert getattr(run_documents[0], "id", None) == context["document"].id - def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies): + def test_dataset_not_found_during_cleaning(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task still updates document and reindexes if dataset vanishes before clean.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -376,7 +377,9 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies): + def test_cleaning_error_continues_to_indexing( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that indexing continues when index cleanup fails.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -400,7 +403,9 @@ class TestDocumentIndexingSyncTask: assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_document_paused_error( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that DocumentIsPausedError does not flip document into error state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -418,7 +423,7 @@ class TestDocumentIndexingSyncTask: assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None - def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_general_error(self, db_session_with_containers: Session, mock_external_dependencies): """Test that indexing errors are persisted to document state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index cf1a8666f3..6c1454b6d8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.document_task import DocumentTask from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( @@ -51,7 +52,7 @@ class TestDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -71,14 +72,14 @@ class TestDocumentIndexingTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -133,7 +134,7 @@ class TestDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. @@ -153,14 +154,14 @@ class TestDocumentIndexingTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -221,7 +222,9 @@ class TestDocumentIndexingTasks: return dataset, documents - def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_document_indexing_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with multiple documents. @@ -262,7 +265,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 3 def test_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -286,7 +289,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. @@ -332,7 +335,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 2 # Only existing documents def test_document_indexing_task_indexing_runner_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of IndexingRunner exceptions. @@ -373,7 +376,7 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test processing documents with mixed initial states. @@ -456,7 +459,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 4 def test_document_indexing_task_billing_sandbox_plan_batch_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for sandbox plan batch upload limit. @@ -518,7 +521,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner"].assert_not_called() def test_document_indexing_task_billing_disabled_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful processing when billing is disabled. @@ -554,7 +557,7 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of DocumentIsPausedError from IndexingRunner. @@ -597,7 +600,9 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== - def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_old_document_indexing_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test document_indexing_task basic functionality. @@ -619,7 +624,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_normal_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_document_indexing_task basic functionality. @@ -643,7 +648,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_priority_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_document_indexing_task basic functionality. @@ -667,7 +672,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_document_indexing_with_tenant_queue_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test _document_indexing_with_tenant_queue function with no waiting tasks. @@ -717,7 +722,7 @@ class TestDocumentIndexingTasks: mock_task_func.delay.assert_not_called() def test_document_indexing_with_tenant_queue_with_waiting_tasks( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis. @@ -776,7 +781,7 @@ class TestDocumentIndexingTasks: assert len(remaining_tasks) == 1 def test_document_indexing_with_tenant_queue_error_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling in _document_indexing_with_tenant_queue using real Redis. @@ -848,7 +853,7 @@ class TestDocumentIndexingTasks: assert len(remaining_tasks) == 0 def test_document_indexing_with_tenant_queue_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation in _document_indexing_with_tenant_queue using real Redis. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index a9a8c0f30c..208fc1aa1d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,9 +3,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_update_task import document_indexing_update_task @@ -33,7 +34,7 @@ class TestDocumentIndexingUpdateTask: "runner_instance": runner_instance, } - def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2): + def _create_dataset_document_with_segments(self, db_session_with_containers: Session, *, segment_count: int = 2): fake = Faker() # Account and tenant @@ -41,12 +42,12 @@ class TestDocumentIndexingUpdateTask: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=fake.company(), status="normal") + tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -114,7 +115,7 @@ class TestDocumentIndexingUpdateTask: return dataset, document, node_ids - def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies): + def test_cleans_segments_and_reindexes(self, db_session_with_containers: Session, mock_external_dependencies): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Act @@ -153,7 +154,9 @@ class TestDocumentIndexingUpdateTask: first = run_docs[0] assert getattr(first, "id", None) == document.id - def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies): + def test_clean_error_is_logged_and_indexing_continues( + self, db_session_with_containers: Session, mock_external_dependencies + ): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Force clean to raise; task should continue to indexing @@ -173,7 +176,7 @@ class TestDocumentIndexingUpdateTask: ) assert remaining > 0 - def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found_noop(self, db_session_with_containers: Session, mock_external_dependencies): fake = Faker() # Act with non-existent document id document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 39c58987fd..12440f3e6b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -62,7 +63,7 @@ class TestDuplicateDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -145,7 +146,11 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_segments( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + document_count=3, + segments_per_doc=2, ): """ Helper method to create a test dataset with documents and segments. @@ -197,7 +202,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents, segments def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. @@ -287,7 +292,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _test_duplicate_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful duplicate document indexing with multiple documents. @@ -329,7 +334,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 3 def _test_duplicate_document_indexing_task_with_segment_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test duplicate document indexing with existing segments that need cleanup. @@ -379,7 +384,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def _test_duplicate_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -404,7 +409,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["index_processor"].clean.assert_not_called() def test_duplicate_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. @@ -450,7 +455,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 2 # Only existing documents def _test_duplicate_document_indexing_task_indexing_runner_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of IndexingRunner exceptions. @@ -491,7 +496,7 @@ class TestDuplicateDocumentIndexingTasks: assert updated_document.processing_started_at is not None def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for sandbox plan batch upload limit. @@ -554,7 +559,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for vector space limit. @@ -596,7 +601,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_duplicate_document_indexing_task_with_empty_document_list( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of empty document list. @@ -622,7 +627,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_deprecated_duplicate_document_indexing_task_delegates_to_core( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that deprecated duplicate_document_indexing_task delegates to core function. @@ -655,7 +660,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_duplicate_document_indexing_task with tenant isolation queue. @@ -698,7 +703,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_duplicate_document_indexing_task with tenant isolation queue. @@ -742,7 +747,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant queue wrapper processes next queued tasks. @@ -789,7 +794,7 @@ class TestDuplicateDocumentIndexingTasks: mock_queue.delete_task_key.assert_not_called() def test_successful_duplicate_document_indexing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test successful duplicate document indexing flow.""" self._test_duplicate_document_indexing_task_success( @@ -797,7 +802,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when dataset is not found.""" self._test_duplicate_document_indexing_task_dataset_not_found( @@ -805,7 +810,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing with billing enabled and sandbox plan.""" self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( @@ -813,7 +818,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when billing limit is exceeded.""" self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( @@ -821,7 +826,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_runner_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when IndexingRunner raises an error.""" self._test_duplicate_document_indexing_task_indexing_runner_exception( @@ -829,7 +834,7 @@ class TestDuplicateDocumentIndexingTasks: ) def _test_duplicate_document_indexing_task_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" # Arrange @@ -860,7 +865,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_duplicate_document_indexing_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" self._test_duplicate_document_indexing_task_document_is_paused( @@ -868,7 +873,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_cleans_old_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test that duplicate document indexing cleans old segments.""" self._test_duplicate_document_indexing_task_with_segment_cleanup( diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py index 177af266fb..a697878bb6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestMailChangeMailTask: "get_email_i18n_service": mock_get_email_i18n_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -72,7 +73,7 @@ class TestMailChangeMailTask: return account def test_send_change_mail_task_success_old_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for old_email phase. @@ -103,7 +104,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_success_new_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for new_email phase. @@ -134,7 +135,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when mail service is not initialized. @@ -159,7 +160,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_not_called() def test_send_change_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when email service raises an exception. @@ -191,7 +192,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email completed notification task execution. @@ -224,7 +225,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when mail service is not initialized. @@ -247,7 +248,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_email.assert_not_called() def test_send_change_mail_completed_notification_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when email service raises an exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 8343711998..8e9da6aaaa 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import delete +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -37,7 +38,7 @@ class TestSendEmailCodeLoginMailTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -71,7 +72,7 @@ class TestSendEmailCodeLoginMailTask: "email_service_instance": mock_email_service_instance, } - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account for testing. @@ -98,7 +99,7 @@ class TestSendEmailCodeLoginMailTask: return account - def _create_test_tenant_and_account(self, db_session_with_containers, fake=None): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant and account for testing. @@ -138,7 +139,7 @@ class TestSendEmailCodeLoginMailTask: return account, tenant def test_send_email_code_login_mail_task_success_english( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in English. @@ -182,7 +183,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_chinese( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in Chinese. @@ -221,7 +222,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_multiple_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending with multiple languages. @@ -261,7 +262,7 @@ class TestSendEmailCodeLoginMailTask: assert call_args[1]["template_context"]["code"] == test_codes[i] def test_send_email_code_login_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when mail service is not initialized. @@ -299,7 +300,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_not_called() def test_send_email_code_login_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when email service raises an exception. @@ -346,7 +347,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_invalid_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with invalid parameters. @@ -388,7 +389,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_called_once() def test_send_email_code_login_mail_task_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with edge cases and boundary conditions. @@ -451,7 +452,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with database integration. @@ -497,7 +498,7 @@ class TestSendEmailCodeLoginMailTask: assert account.status == "active" def test_send_email_code_login_mail_task_redis_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with Redis integration. @@ -541,7 +542,7 @@ class TestSendEmailCodeLoginMailTask: redis_client.delete(cache_key) def test_send_email_code_login_mail_task_error_handling_comprehensive( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error handling for email code login mail task. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 95a867dbb5..f505361727 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from sqlalchemy import delete +from sqlalchemy.orm import Session from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -172,7 +173,9 @@ def _create_workflow_pause_state( db_session_with_containers.commit() -def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): +def test_dispatch_human_input_email_task_integration( + monkeypatch: pytest.MonkeyPatch, db_session_with_containers: Session +): tenant, account = _create_workspace_member(db_session_with_containers) workflow_run_id = str(uuid.uuid4()) workflow_id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py index 1a20b6deec..f8e54ea9e6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from tasks.mail_inner_task import send_inner_email_task @@ -51,7 +52,7 @@ class TestMailInnerTask: }, } - def test_send_inner_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful email sending with valid data. @@ -90,7 +91,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_single_recipient(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_single_recipient( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with single recipient. @@ -126,7 +129,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_empty_substitutions(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_empty_substitutions( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with empty substitutions. @@ -163,7 +168,7 @@ class TestMailInnerTask: ) def test_send_inner_email_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when mail service is not initialized. @@ -193,7 +198,7 @@ class TestMailInnerTask: mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() def test_send_inner_email_template_rendering_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when template rendering fails. @@ -222,7 +227,9 @@ class TestMailInnerTask: # Verify no email service calls due to exception mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() - def test_send_inner_email_service_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_service_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending when email service fails. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index d34828c4b1..c8c7a4d961 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import delete, select +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from libs.email_i18n import EmailType @@ -42,7 +43,7 @@ class TestMailInviteMemberTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" # Clear all test data db_session_with_containers.execute(delete(TenantAccountJoin)) @@ -78,7 +79,7 @@ class TestMailInviteMemberTask: "config": mock_config, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -147,7 +148,7 @@ class TestMailInviteMemberTask: redis_client.setex(cache_key, 24 * 60 * 60, json.dumps(invitation_data)) # 24 hours return token - def _create_pending_account_for_invitation(self, db_session_with_containers, email, tenant): + def _create_pending_account_for_invitation(self, db_session_with_containers: Session, email, tenant): """ Helper method to create a pending account for invitation testing. @@ -185,7 +186,9 @@ class TestMailInviteMemberTask: return account - def test_send_invite_member_mail_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_invite_member_mail_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful invitation email sending with all parameters. @@ -231,7 +234,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_different_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test invitation email sending with different language codes. @@ -263,7 +266,7 @@ class TestMailInviteMemberTask: assert call_args[1]["language_code"] == language def test_send_invite_member_mail_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test behavior when mail service is not initialized. @@ -292,7 +295,7 @@ class TestMailInviteMemberTask: mock_email_service.send_email.assert_not_called() def test_send_invite_member_mail_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when email service raises an exception. @@ -322,7 +325,7 @@ class TestMailInviteMemberTask: assert "Send invite member mail to %s failed" in error_call def test_send_invite_member_mail_template_context_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test template context contains all required fields for email rendering. @@ -368,7 +371,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_integration_with_redis_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test integration with Redis token validation. @@ -407,7 +410,7 @@ class TestMailInviteMemberTask: assert invitation_data["workspace_id"] == tenant.id def test_send_invite_member_mail_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending with special characters in names and workspace names. @@ -449,7 +452,7 @@ class TestMailInviteMemberTask: assert template_context["workspace_name"] == workspace_name def test_send_invite_member_mail_real_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test real database integration with actual invitation flow. @@ -501,7 +504,7 @@ class TestMailInviteMemberTask: assert tenant_join.role == TenantAccountRole.NORMAL def test_send_invite_member_mail_token_lifecycle_management( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test token lifecycle management and validation. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py index e08b099480..176645a4ab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py @@ -11,6 +11,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -44,7 +45,7 @@ class TestMailOwnerTransferTask: "get_email_service": mock_get_email_service, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create test account and tenant for testing. @@ -86,7 +87,9 @@ class TestMailOwnerTransferTask: return account, tenant - def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_owner_transfer_confirm_task_success( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """ Test successful owner transfer confirmation email sending. @@ -127,7 +130,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_owner_transfer_confirm_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test owner transfer confirmation email when mail service is not initialized. @@ -158,7 +161,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_owner_transfer_confirm_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in owner transfer confirmation email. @@ -192,7 +195,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_old_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful old owner transfer notification email sending. @@ -234,7 +237,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email def test_send_old_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test old owner transfer notification email when mail service is not initialized. @@ -265,7 +268,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_old_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in old owner transfer notification email. @@ -299,7 +302,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_new_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful new owner transfer notification email sending. @@ -338,7 +341,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_new_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test new owner transfer notification email when mail service is not initialized. @@ -367,7 +370,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_new_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in new owner transfer notification email. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py index cced6f7780..071971f324 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist @@ -35,7 +36,7 @@ class TestMailRegisterTask: "get_email_service": mock_get_email_service, } - def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_success(self, db_session_with_containers: Session, mock_mail_dependencies): """Test successful email registration mail sending.""" fake = Faker() language = "en-US" @@ -56,7 +57,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test email registration task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -66,7 +67,9 @@ class TestMailRegisterTask: mock_mail_dependencies["get_email_service"].assert_not_called() mock_mail_dependencies["email_service"].send_email.assert_not_called() - def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_exception_handling( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """Test email registration task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") @@ -79,7 +82,7 @@ class TestMailRegisterTask: mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) def test_send_email_register_mail_task_when_account_exist_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test successful email registration mail sending when account exists.""" fake = Faker() @@ -105,7 +108,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_when_account_exist_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -118,7 +121,7 @@ class TestMailRegisterTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_email_register_mail_task_when_account_exist_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index f01fcc1742..5eea985fdc 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,12 +4,13 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from flask import Flask from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Pipeline from models.workflow import Workflow from tasks.rag_pipeline.priority_rag_pipeline_run_task import ( @@ -69,14 +70,14 @@ class TestRagPipelineRunTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -725,7 +726,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test successful run_single_rag_pipeline_task execution. @@ -760,7 +761,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -805,7 +806,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with non-existent database entities. diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index b43b622870..03c02ea341 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -3,6 +3,7 @@ from unittest.mock import ANY, call, patch import pytest from sqlalchemy import delete, func, select +from sqlalchemy.orm import Session from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType @@ -117,7 +118,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: class TestDeleteDraftVariablesBatch: - def test_delete_draft_variables_batch_success(self, db_session_with_containers): + def test_delete_draft_variables_batch_success(self, db_session_with_containers: Session): """Test successful deletion of draft variables in batches.""" _, app1 = _create_tenant_and_app(db_session_with_containers) _, app2 = _create_tenant_and_app(db_session_with_containers) @@ -137,7 +138,7 @@ class TestDeleteDraftVariablesBatch: assert app1_remaining_count == 0 assert app2_remaining_count == 100 - def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers): + def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers: Session): """Test deletion when no draft variables exist for the app.""" result = delete_draft_variables_batch(str(uuid.uuid4()), 1000) @@ -176,7 +177,7 @@ class TestDeleteDraftVariableOffloadData: """Test the Offload data cleanup functionality.""" @patch("extensions.ext_storage.storage") - def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers): + def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers: Session): """Test successful deletion of offload data.""" tenant, app = _create_tenant_and_app(db_session_with_containers) offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3) diff --git a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py index 34a1941c39..6365207661 100644 --- a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py +++ b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py @@ -1,12 +1,14 @@ from pathlib import Path +import pytest + from extensions.storage.opendal_storage import OpenDALStorage class TestOpenDALFsDefaultRoot: """Test that OpenDALStorage with scheme='fs' works correctly when no root is provided.""" - def test_fs_without_root_uses_default(self, tmp_path, monkeypatch): + def test_fs_without_root_uses_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When no root is specified, the default 'storage' should be used and passed to the Operator.""" # Change to tmp_path so the default "storage" dir is created there monkeypatch.chdir(tmp_path) @@ -25,7 +27,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_default_root.txt") - def test_fs_with_explicit_root(self, tmp_path): + def test_fs_with_explicit_root(self, tmp_path: Path): """When root is explicitly provided, it should be used.""" custom_root = str(tmp_path / "custom_storage") storage = OpenDALStorage(scheme="fs", root=custom_root) @@ -38,7 +40,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_explicit_root.txt") - def test_fs_with_env_var_root(self, tmp_path, monkeypatch): + def test_fs_with_env_var_root(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs.""" env_root = str(tmp_path / "env_storage") monkeypatch.setenv("OPENDAL_FS_ROOT", env_root) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index b00d827e37..6402e7da2b 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -175,7 +175,7 @@ class TestWorkflowPauseIntegration: """Comprehensive integration tests for workflow pause functionality.""" @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers): + def setup_test_data(self, db_session_with_containers: Session): """Set up test data for each test method using TestContainers.""" # Create test tenant and account diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 19a41b6186..a5086b4c5d 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -1,12 +1,14 @@ from textwrap import dedent +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): """Test class for JavaScript code executor functionality.""" - def test_javascript_plain(self, flask_app_with_containers): + def test_javascript_plain(self, flask_app_with_containers: Flask): """Test basic JavaScript code execution with console.log output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -14,7 +16,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) assert result_message == "Hello World\n" - def test_javascript_json(self, flask_app_with_containers): + def test_javascript_json(self, flask_app_with_containers: Flask): """Test JavaScript code execution with JSON output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -25,7 +27,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) assert result == '{"Hello":"World"}\n' - def test_javascript_with_code_template(self, flask_app_with_containers): + def test_javascript_with_code_template(self, flask_app_with_containers: Flask): """Test JavaScript workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports JavascriptCodeProvider, _ = self.javascript_imports @@ -37,7 +39,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "HelloWorld"} - def test_javascript_get_runner_script(self, flask_app_with_containers): + def test_javascript_get_runner_script(self, flask_app_with_containers: Flask): """Test JavaScript template transformer runner script generation""" _, NodeJsTemplateTransformer = self.javascript_imports diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index ddb079f00c..8b4c3c3d4a 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -1,12 +1,14 @@ import base64 +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestJinja2CodeExecutor(CodeExecutorTestMixin): """Test class for Jinja2 code executor functionality.""" - def test_jinja2(self, flask_app_with_containers): + def test_jinja2(self, flask_app_with_containers: Flask): """Test basic Jinja2 template execution with variable substitution""" CodeExecutor, CodeLanguage = self.code_executor_imports _, Jinja2TemplateTransformer = self.jinja2_imports @@ -25,7 +27,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): ) assert result == "<>Hello World<>\n" - def test_jinja2_with_code_template(self, flask_app_with_containers): + def test_jinja2_with_code_template(self, flask_app_with_containers: Flask): """Test Jinja2 workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -34,7 +36,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "Hello World"} - def test_jinja2_get_runner_script(self, flask_app_with_containers): + def test_jinja2_get_runner_script(self, flask_app_with_containers: Flask): """Test Jinja2 template transformer runner script generation""" _, Jinja2TemplateTransformer = self.jinja2_imports @@ -43,7 +45,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 - def test_jinja2_template_with_special_characters(self, flask_app_with_containers): + def test_jinja2_template_with_special_characters(self, flask_app_with_containers: Flask): """ Test that templates with special characters (quotes, newlines) render correctly. This is a regression test for issue #26818 where textarea pre-fill values diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py index 6d93df2472..0de41e1312 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,12 +1,14 @@ from textwrap import dedent +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestPython3CodeExecutor(CodeExecutorTestMixin): """Test class for Python3 code executor functionality.""" - def test_python3_plain(self, flask_app_with_containers): + def test_python3_plain(self, flask_app_with_containers: Flask): """Test basic Python3 code execution with print output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -14,7 +16,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) assert result == "Hello World\n" - def test_python3_json(self, flask_app_with_containers): + def test_python3_json(self, flask_app_with_containers: Flask): """Test Python3 code execution with JSON output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -25,7 +27,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) assert result == '{"Hello": "World"}\n' - def test_python3_with_code_template(self, flask_app_with_containers): + def test_python3_with_code_template(self, flask_app_with_containers: Flask): """Test Python3 workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports Python3CodeProvider, _ = self.python3_imports @@ -37,7 +39,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "HelloWorld"} - def test_python3_get_runner_script(self, flask_app_with_containers): + def test_python3_get_runner_script(self, flask_app_with_containers: Flask): """Test Python3 template transformer runner script generation""" _, Python3TemplateTransformer = self.python3_imports diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index d3e864a75a..78413a0798 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -67,7 +67,7 @@ class TestActivateCheckApi: assert response["data"]["email"] == "invitee@example.com" @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_invalid_invitation_token(self, mock_get_invitation, app): + def test_check_invalid_invitation_token(self, mock_get_invitation, app: Flask): """ Test checking invalid invitation token. @@ -227,7 +227,7 @@ class TestActivateApi: mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_activation_with_invalid_token(self, mock_get_invitation, app): + def test_activation_with_invalid_token(self, mock_get_invitation, app: Flask): """ Test account activation with invalid token. diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index b7bc73da5f..7b2c7569fe 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -140,7 +140,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") - def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending blocked by IP rate limit. @@ -160,7 +160,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.login.AccountService.get_user_through_email") - def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app): + def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending to frozen account. @@ -353,7 +353,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app): + def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app: Flask): """ Test email code login with invalid token. @@ -375,7 +375,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app): + def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app: Flask): """ Test email code login with mismatched email. @@ -397,7 +397,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app): + def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app: Flask): """ Test email code login with incorrect code. diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index d089be8905..5284f29eed 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -9,7 +9,7 @@ This module tests the core authentication endpoints including: """ import base64 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from flask import Flask @@ -52,12 +52,12 @@ class TestLoginApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(LoginApi, "/login") return app.test_client() @@ -97,7 +97,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -141,14 +141,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") def test_successful_login_with_valid_invitation( self, - mock_reset_rate_limit, + mock_reset_rate_limit: Mock, mock_login, mock_get_tenants, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -188,7 +188,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login rejection when rate limit is exceeded. @@ -216,7 +216,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @patch("controllers.console.auth.login.BillingService.is_email_in_freeze") - def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app): + def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask): """ Test login rejection for frozen accounts. @@ -253,7 +253,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, ): """ Test login failure with invalid credentials. @@ -290,7 +290,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( - self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app + self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask ): """ Test login rejection for banned accounts. @@ -328,14 +328,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.FeatureService.get_system_features") def test_login_fails_when_no_workspace_and_limit_exceeded( self, - mock_get_features, - mock_get_tenants, - mock_authenticate, - mock_get_invitation, - mock_is_rate_limit, - mock_db, - app, - mock_account, + mock_get_features: MagicMock, + mock_get_tenants: MagicMock, + mock_authenticate: MagicMock, + mock_get_invitation: MagicMock, + mock_is_rate_limit: MagicMock, + mock_db: MagicMock, + app: Flask, + mock_account: MagicMock, ): """ Test login failure when user has no workspace and workspace limit exceeded. @@ -367,7 +367,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login failure when invitation email doesn't match login email. @@ -491,7 +491,7 @@ class TestLogoutApi: @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") def test_successful_logout( - self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account + self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account ): """ Test successful logout flow. @@ -518,7 +518,7 @@ class TestLogoutApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.current_account_with_tenant") @patch("controllers.console.auth.login.flask_login") - def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app): + def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask): """ Test logout for anonymous (not logged in) user. diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index d010f60866..15c95f6b94 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -28,12 +28,12 @@ class TestRefreshTokenApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(RefreshTokenApi, "/refresh-token") return app.test_client() diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index 810f1b94fc..defa9064fd 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -49,7 +49,7 @@ class TestPartnerTenants: mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} - def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_success(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test successful partner tenants bindings sync.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -79,7 +79,7 @@ class TestPartnerTenants: mock_account.id, "partner-key-123", click_id ) - def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_invalid_partner_key_base64(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that invalid base64 partner_key raises BadRequest.""" # Arrange invalid_partner_key = "invalid-base64-!@#$" @@ -104,7 +104,7 @@ class TestPartnerTenants: resource.put(invalid_partner_key) assert "Invalid partner_key" in str(exc_info.value) - def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_missing_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that missing click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -128,7 +128,9 @@ class TestPartnerTenants: with pytest.raises(BadRequest): resource.put(partner_key_encoded) - def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_billing_service_json_decode_error( + self, app: Flask, mock_account, mock_billing_service, mock_decorators + ): """Test handling of billing service JSON decode error. When billing service returns non-200 status code with invalid JSON response, @@ -174,7 +176,7 @@ class TestPartnerTenants: assert isinstance(exc_info.value, json.JSONDecodeError) assert "Expecting value" in str(exc_info.value) - def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -199,7 +201,7 @@ class TestPartnerTenants: resource.put(partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_partner_key_after_decode(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty partner_key after decode raises BadRequest.""" # Arrange # Base64 encode an empty string @@ -225,7 +227,7 @@ class TestPartnerTenants: resource.put(empty_partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_user_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty user id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 032b1377a4..99a90f3b67 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -17,7 +17,7 @@ def app(): return app -def test_parse_openapi_to_tool_bundle_operation_id(app): +def test_parse_openapi_to_tool_bundle_operation_id(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -63,7 +63,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): assert tool_bundles[2].operation_id == "createResource" -def test_parse_openapi_to_tool_bundle_properties_all_of(app): +def test_parse_openapi_to_tool_bundle_properties_all_of(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -118,7 +118,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} -def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): +def test_parse_openapi_to_tool_bundle_default_value_type_casting(app: Flask): """ Test that default values are properly cast to match parameter types. This addresses the issue where array default values like [] cause validation errors diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py index 762d7b9090..e7f7cabecd 100644 --- a/api/tests/unit_tests/services/controller_api.py +++ b/api/tests/unit_tests/services/controller_api.py @@ -146,7 +146,7 @@ class ControllerApiTestDataFactory: return app @staticmethod - def create_api_instance(app): + def create_api_instance(app: Flask): """ Create a Flask-RESTX API instance. @@ -160,7 +160,12 @@ class ControllerApiTestDataFactory: return api @staticmethod - def create_test_client(app, api, resource_class, route): + def create_test_client( + app: Flask, + api: Api, + resource_class: type, + route: str, + ): """ Create a Flask test client with a resource registered. @@ -302,7 +307,7 @@ class TestDatasetListApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """ Create Flask-RESTX API instance. @@ -311,7 +316,7 @@ class TestDatasetListApi: return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """ Create test client with DatasetListApi registered. @@ -472,12 +477,12 @@ class TestDatasetApiGet: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with DatasetApi registered.""" return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets/") @@ -588,12 +593,12 @@ class TestDatasetApiCreate: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with DatasetApi registered.""" return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets") @@ -681,12 +686,12 @@ class TestHitTestingApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with HitTestingApi registered.""" return ControllerApiTestDataFactory.create_test_client( app, api, HitTestingApi, "/datasets//hit-testing" @@ -799,12 +804,12 @@ class TestExternalDatasetApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client_list(self, app, api): + def client_list(self, app: Flask, api: Api): """Create test client for external knowledge API list endpoint.""" return ControllerApiTestDataFactory.create_test_client( app, api, ExternalApiTemplateListApi, "/datasets/external-knowledge-api" diff --git a/api/uv.lock b/api/uv.lock index e42fc343b3..6f75c9f6fe 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -2657,14 +2657,14 @@ wheels = [ [[package]] name = "gitpython" -version = "3.1.47" +version = "3.1.49" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gitdb" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/bd/50db468e9b1310529a19fce651b3b0e753b5c07954d486cba31bbee9a5d5/gitpython-3.1.47.tar.gz", hash = "sha256:dba27f922bd2b42cb54c87a8ab3cb6beb6bf07f3d564e21ac848913a05a8a3cd", size = 216978, upload-time = "2026-04-22T02:44:44.059Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/63/210aaa302d6a0a78daa67c5c15bbac2cad361722841278b0209b6da20855/gitpython-3.1.49.tar.gz", hash = "sha256:42f9399c9eb33fc581014bedd76049dfbaf6375aa2a5754575966387280315e1", size = 219367, upload-time = "2026-04-29T00:31:20.478Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f2/c5/a1bc0996af85757903cf2bf444a7824e68e0035ce63fb41d6f76f9def68b/gitpython-3.1.47-py3-none-any.whl", hash = "sha256:489f590edfd6d20571b2c0e72c6a6ac6915ee8b8cd04572330e3842207a78905", size = 209547, upload-time = "2026-04-22T02:44:41.271Z" }, + { url = "https://files.pythonhosted.org/packages/fd/6f/b842bfa6f21d6f87c57f9abf7194225e55279d96d869775e19e9f7236fc5/gitpython-3.1.49-py3-none-any.whl", hash = "sha256:024b0422d7f84d15cd794844e029ffebd4c5d42a7eb9b936b458697ef550a02c", size = 212190, upload-time = "2026-04-29T00:31:18.412Z" }, ] [[package]] @@ -3740,14 +3740,14 @@ wheels = [ [[package]] name = "mako" -version = "1.3.11" +version = "1.3.12" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" } +sdist = { url = "https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" }, + { url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" }, ] [[package]] diff --git a/docker/.env.default b/docker/.env.default new file mode 100644 index 0000000000..6f6683b9f5 --- /dev/null +++ b/docker/.env.default @@ -0,0 +1,51 @@ +# ------------------------------------------------------------------ +# Minimal defaults for Docker Compose deployments. +# +# Keep local changes in .env. Use .env.example as the full reference +# for advanced and service-specific settings. +# ------------------------------------------------------------------ + +# Public URLs used when Dify generates links. Change these together when +# exposing Dify under another hostname, IP address, or port. +CONSOLE_WEB_URL=http://localhost +SERVICE_API_URL=http://localhost +APP_WEB_URL=http://localhost +FILES_URL=http://localhost +INTERNAL_FILES_URL=http://api:5001 +TRIGGER_URL=http://localhost +ENDPOINT_URL_TEMPLATE=http://localhost/e/{hook_id} +NEXT_PUBLIC_SOCKET_URL=ws://localhost +EXPOSE_PLUGIN_DEBUGGING_HOST=localhost +EXPOSE_PLUGIN_DEBUGGING_PORT=5003 + +# Built-in metadata database defaults. +DB_TYPE=postgresql +DB_USERNAME=postgres +DB_PASSWORD=difyai123456 +DB_HOST=db_postgres +DB_PORT=5432 +DB_DATABASE=dify + +# Built-in Redis defaults. +REDIS_HOST=redis +REDIS_PORT=6379 +REDIS_PASSWORD=difyai123456 + +# Default file storage. +STORAGE_TYPE=opendal +OPENDAL_SCHEME=fs +OPENDAL_FS_ROOT=storage + +# Default vector database. +VECTOR_STORE=weaviate + +# Internal service authentication. Paired values must match. +PLUGIN_DAEMON_KEY=lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi +PLUGIN_DIFY_INNER_API_KEY=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1 + +# Host ports. +EXPOSE_NGINX_PORT=80 +EXPOSE_NGINX_SSL_PORT=443 + +# Docker Compose profiles for bundled services. +COMPOSE_PROFILES=${VECTOR_STORE:-weaviate},${DB_TYPE:-postgresql} diff --git a/docker/.env.example b/docker/.env.example index 29741474fa..122228cdd1 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1003,7 +1003,7 @@ NOTION_INTERNAL_SECRET= # ------------------------------ # Mail type, support: resend, smtp, sendgrid -MAIL_TYPE=resend +MAIL_TYPE= # Default send from email address, if not specified # If using SendGrid, use the 'from' field for authentication if necessary. @@ -1011,7 +1011,7 @@ MAIL_DEFAULT_SEND_FROM= # API-Key for the Resend email provider, used when MAIL_TYPE is `resend`. RESEND_API_URL=https://api.resend.com -RESEND_API_KEY=your-resend-api-key +RESEND_API_KEY= # SMTP server configuration, used when MAIL_TYPE is `smtp` @@ -1359,10 +1359,10 @@ NGINX_ENABLE_CERTBOT_CHALLENGE=false # ------------------------------ # Email address (required to get certificates from Let's Encrypt) -CERTBOT_EMAIL=your_email@example.com +CERTBOT_EMAIL= # Domain name -CERTBOT_DOMAIN=your_domain.com +CERTBOT_DOMAIN= # certbot command options # i.e: --force-renewal --dry-run --test-cert --debug diff --git a/docker/README.md b/docker/README.md index 3130fa9886..3a7f4c2ad5 100644 --- a/docker/README.md +++ b/docker/README.md @@ -7,28 +7,28 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T - **Certbot Container**: `docker-compose.yaml` now contains `certbot` for managing SSL certificates. This container automatically renews certificates and ensures secure HTTPS connections.\ For more information, refer `docker/certbot/README.md`. -- **Persistent Environment Variables**: Environment variables are now managed through a `.env` file, ensuring that your configurations persist across deployments. +- **Persistent Environment Variables**: Default environment variables are managed through `.env.default`, while local overrides are stored in `.env`, ensuring that your configurations persist across deployments. > What is `.env`?

- > The `.env` file is a crucial component in Docker and Docker Compose environments, serving as a centralized configuration file where you can define environment variables that are accessible to the containers at runtime. This file simplifies the management of environment settings across different stages of development, testing, and production, providing consistency and ease of configuration to deployments. + > The `.env` file is a local override file. Keep it small by adding only the values that differ from `.env.default`. Use `.env.example` as the full reference when you need advanced configuration. - **Unified Vector Database Services**: All vector database services are now managed from a single Docker Compose file `docker-compose.yaml`. You can switch between different vector databases by setting the `VECTOR_STORE` environment variable in your `.env` file. -- **Mandatory .env File**: A `.env` file is now required to run `docker compose up`. This file is crucial for configuring your deployment and for any custom settings to persist through upgrades. +- **Local .env Overrides**: The `dify-compose` and `dify-compose.ps1` wrappers create `.env` if it is missing and generate a persistent `SECRET_KEY` for this deployment. ### How to Deploy Dify with `docker-compose.yaml` 1. **Prerequisites**: Ensure Docker and Docker Compose are installed on your system. 1. **Environment Setup**: - Navigate to the `docker` directory. - - Copy the `.env.example` file to a new file named `.env` by running `cp .env.example .env`. - - Customize the `.env` file as needed. Refer to the `.env.example` file for detailed configuration options. - - **Optional (Recommended for upgrades)**: - You may use the environment synchronization tool to help keep your `.env` file aligned with the latest `.env.example` updates, while preserving your custom settings. - This is especially useful when upgrading Dify or managing a large, customized `.env` file. + - No copy step is required. The `dify-compose` wrappers create `.env` if it is missing and write a generated `SECRET_KEY` to it. + - When prompted on first run, press Enter to use the default deployment, or answer `y` to stop and edit `.env` first. + - Customize `.env` only when you need to override defaults from `.env.default`. Refer to `.env.example` for the full list of available variables. + - **Optional (for advanced deployments)**: + If you maintain a full `.env` file copied from `.env.example`, you may use the environment synchronization tool to keep it aligned with the latest `.env.example` updates while preserving your custom settings. See the [Environment Variables Synchronization](#environment-variables-synchronization) section below. 1. **Running the Services**: - - Execute `docker compose up` from the `docker` directory to start the services. + - Execute `./dify-compose up -d` from the `docker` directory to start the services. On Windows PowerShell, run `.\dify-compose.ps1 up -d`. - To specify a vector database, set the `VECTOR_STORE` variable in your `.env` file to your desired vector database service, such as `milvus`, `weaviate`, or `opensearch`. 1. **SSL Certificate Setup**: - Refer `docker/certbot/README.md` to set up SSL certificates using Certbot. @@ -58,7 +58,13 @@ For users migrating from the `docker-legacy` setup: 1. **Data Migration**: - Ensure that data from services like databases and caches is backed up and migrated appropriately to the new structure if necessary. -### Overview of `.env` +### Overview of `.env.default`, `.env`, and `.env.example` + +- `.env.default` contains the minimal default configuration for Docker Compose deployments. +- `.env` contains the generated `SECRET_KEY` plus any local overrides. +- `.env.example` is the full reference for advanced configuration. + +The `dify-compose` wrappers merge `.env.default` and `.env` into a temporary environment file, append paired internal service keys when needed, and remove the temporary file after Docker Compose starts. #### Key Modules and Customization @@ -118,9 +124,11 @@ The `.env.example` file provided in the Docker setup is extensive and covers a w ### Environment Variables Synchronization -When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.example`. +When upgrading Dify or pulling the latest changes, new environment variables may be introduced in `.env.default` or `.env.example`. -To help keep your existing `.env` file up to date **without losing your custom values**, an optional environment variables synchronization tool is provided. +If you use the default override-only workflow, review `.env.default` and add only the values you need to override to `.env`. + +If you maintain a full `.env` file copied from `.env.example`, an optional environment variables synchronization tool is provided. > This tool performs a **one-way synchronization** from `.env.example` to `.env`. > Existing values in `.env` are never overwritten automatically. @@ -143,9 +151,9 @@ Before synchronization, the current `.env` file is saved to the `env-backup/` di **When to use** -- After upgrading Dify to a newer version +- After upgrading Dify to a newer version with a full `.env` file - When `.env.example` has been updated with new environment variables -- When managing a large or heavily customized `.env` file +- When managing a large or heavily customized `.env` file copied from `.env.example` **Usage** diff --git a/docker/dify-compose b/docker/dify-compose new file mode 100755 index 0000000000..16bbd6b538 --- /dev/null +++ b/docker/dify-compose @@ -0,0 +1,334 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +DEFAULT_ENV_FILE=".env.default" +USER_ENV_FILE=".env" + +log() { + printf '%s\n' "$*" >&2 +} + +die() { + printf 'Error: %s\n' "$*" >&2 + exit 1 +} + +detect_compose() { + if docker compose version >/dev/null 2>&1; then + COMPOSE_CMD=(docker compose) + return + fi + + if command -v docker-compose >/dev/null 2>&1; then + COMPOSE_CMD=(docker-compose) + return + fi + + die "Docker Compose is not available. Install Docker Compose, then run this command again." +} + +generate_secret_key() { + if command -v openssl >/dev/null 2>&1; then + openssl rand -base64 42 + return + fi + + if command -v dd >/dev/null 2>&1 && command -v base64 >/dev/null 2>&1; then + dd if=/dev/urandom bs=42 count=1 2>/dev/null | base64 | tr -d '\n' + printf '\n' + return + fi + + return 1 +} + +ensure_env_files() { + [[ -f "$DEFAULT_ENV_FILE" ]] || die "$DEFAULT_ENV_FILE is missing." + + if [[ -f "$USER_ENV_FILE" ]]; then + return + fi + + : >"$USER_ENV_FILE" + + if [[ ! -t 0 ]]; then + log "Created $USER_ENV_FILE for local overrides." + return + fi + + printf 'Created %s for local overrides.\n' "$USER_ENV_FILE" + printf 'Do you need a custom deployment now? (Most users can press Enter to skip.) [y/N] ' + read -r answer + + case "${answer:-}" in + y | Y | yes | YES | Yes) + cat <<'EOF' +Edit .env with the settings you want to override, using .env.example as the full reference. +Run ./dify-compose up -d again when you are ready. +EOF + exit 0 + ;; + esac +} + +user_env_value() { + local key="$1" + awk -F= -v target="$key" ' + /^[[:space:]]*#/ || !/=/{ next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + value = substr($0, index($0, "=") + 1) + gsub(/^[[:space:]]+|[[:space:]]+$/, "", value) + if ((value ~ /^".*"$/) || (value ~ /^'\''.*'\''$/)) { + value = substr(value, 2, length(value) - 2) + } + result = value + } + } + END { print result } + ' "$USER_ENV_FILE" +} + +set_user_env_value() { + local key="$1" + local value="$2" + local temp_file + + temp_file="$(mktemp "${TMPDIR:-/tmp}/dify-env.XXXXXX")" + awk -F= -v target="$key" -v replacement="$key=$value" ' + BEGIN { replaced = 0 } + /^[[:space:]]*#/ || !/=/{ print; next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + if (!replaced) { + print replacement + replaced = 1 + } + next + } + print + } + END { + if (!replaced) { + print replacement + } + } + ' "$USER_ENV_FILE" >"$temp_file" + mv "$temp_file" "$USER_ENV_FILE" +} + +ensure_secret_key() { + local current_secret_key + local secret_key + + current_secret_key="$(user_env_value SECRET_KEY)" + if [[ -n "$current_secret_key" ]]; then + return + fi + + secret_key="$(generate_secret_key)" || die "Unable to generate SECRET_KEY. Install openssl or configure SECRET_KEY in .env." + set_user_env_value SECRET_KEY "$secret_key" + log "Generated SECRET_KEY in $USER_ENV_FILE." +} + +env_value() { + local key="$1" + awk -F= -v target="$key" ' + /^[[:space:]]*#/ || !/=/{ next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + value = substr($0, index($0, "=") + 1) + gsub(/^[[:space:]]+|[[:space:]]+$/, "", value) + if ((value ~ /^".*"$/) || (value ~ /^'\''.*'\''$/)) { + value = substr(value, 2, length(value) - 2) + } + result = value + } + } + END { print result } + ' "$DEFAULT_ENV_FILE" "$USER_ENV_FILE" +} + +user_overrides() { + local key="$1" + grep -Eq "^[[:space:]]*${key}[[:space:]]*=" "$USER_ENV_FILE" +} + +write_merged_env() { + awk ' + function trim(s) { + sub(/^[[:space:]]+/, "", s) + sub(/[[:space:]]+$/, "", s) + return s + } + + /^[[:space:]]*#/ || !/=/{ next } + + { + key = $0 + sub(/=.*/, "", key) + key = trim(key) + if (key == "") { + next + } + + value = substr($0, index($0, "=") + 1) + value = trim(value) + + if (!(key in seen)) { + order[++count] = key + seen[key] = 1 + } + + values[key] = value + } + + END { + for (i = 1; i <= count; i++) { + key = order[i] + print key "=" values[key] + } + } + ' "$DEFAULT_ENV_FILE" "$USER_ENV_FILE" >"$MERGED_ENV_FILE" +} + +set_merged_env_value() { + local key="$1" + local value="$2" + local temp_file + + temp_file="$(mktemp "${TMPDIR:-/tmp}/dify-compose-env.XXXXXX")" + awk -F= -v target="$key" -v replacement="$key=$value" ' + BEGIN { replaced = 0 } + /^[[:space:]]*#/ || !/=/{ print; next } + { + key = $1 + gsub(/^[[:space:]]+|[[:space:]]+$/, "", key) + if (key == target) { + if (!replaced) { + print replacement + replaced = 1 + } + next + } + print + } + END { + if (!replaced) { + print replacement + } + } + ' "$MERGED_ENV_FILE" >"$temp_file" + mv "$temp_file" "$MERGED_ENV_FILE" +} + +set_if_not_overridden() { + local key="$1" + local value="$2" + + if user_overrides "$key"; then + return + fi + + set_merged_env_value "$key" "$value" +} + +metadata_db_host() { + case "$1" in + mysql) printf 'db_mysql' ;; + postgresql | '') printf 'db_postgres' ;; + *) printf '%s' "$(env_value DB_HOST)" ;; + esac +} + +metadata_db_port() { + case "$1" in + mysql) printf '3306' ;; + postgresql | '') printf '5432' ;; + *) printf '%s' "$(env_value DB_PORT)" ;; + esac +} + +metadata_db_user() { + case "$1" in + mysql) printf 'root' ;; + postgresql | '') printf 'postgres' ;; + *) printf '%s' "$(env_value DB_USERNAME)" ;; + esac +} + +build_merged_env() { + MERGED_ENV_FILE="$(mktemp "${TMPDIR:-/tmp}/dify-compose.XXXXXX")" + trap 'rm -f "$MERGED_ENV_FILE"' EXIT + + write_merged_env + + local db_type + local redis_host + local redis_port + local redis_username + local redis_password + local redis_auth + local code_execution_api_key + local weaviate_api_key + + db_type="$(env_value DB_TYPE)" + + set_if_not_overridden DB_HOST "$(metadata_db_host "$db_type")" + set_if_not_overridden DB_PORT "$(metadata_db_port "$db_type")" + set_if_not_overridden DB_USERNAME "$(metadata_db_user "$db_type")" + + if ! user_overrides CELERY_BROKER_URL; then + redis_host="$(env_value REDIS_HOST)" + redis_port="$(env_value REDIS_PORT)" + redis_username="$(env_value REDIS_USERNAME)" + redis_password="$(env_value REDIS_PASSWORD)" + redis_auth="" + + if [[ -n "$redis_username" && -n "$redis_password" ]]; then + redis_auth="${redis_username}:${redis_password}@" + elif [[ -n "$redis_password" ]]; then + redis_auth=":${redis_password}@" + elif [[ -n "$redis_username" ]]; then + redis_auth="${redis_username}@" + fi + + set_merged_env_value CELERY_BROKER_URL "redis://${redis_auth}${redis_host:-redis}:${redis_port:-6379}/1" + fi + + if ! user_overrides SANDBOX_API_KEY; then + code_execution_api_key="$(env_value CODE_EXECUTION_API_KEY)" + set_if_not_overridden SANDBOX_API_KEY "${code_execution_api_key:-dify-sandbox}" + fi + + if ! user_overrides WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS; then + weaviate_api_key="$(env_value WEAVIATE_API_KEY)" + set_if_not_overridden WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS \ + "${weaviate_api_key:-WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih}" + fi +} + +main() { + detect_compose + ensure_env_files + ensure_secret_key + build_merged_env + + if [[ "$#" -eq 0 ]]; then + set -- up -d + fi + + "${COMPOSE_CMD[@]}" --env-file "$MERGED_ENV_FILE" "$@" +} + +main "$@" diff --git a/docker/dify-compose.ps1 b/docker/dify-compose.ps1 new file mode 100644 index 0000000000..851f8b76fe --- /dev/null +++ b/docker/dify-compose.ps1 @@ -0,0 +1,317 @@ +$ErrorActionPreference = "Stop" +Set-StrictMode -Version Latest + +$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path +Set-Location $ScriptDir + +$DefaultEnvFile = ".env.default" +$UserEnvFile = ".env" +$MergedEnvFile = $null +$Utf8NoBom = New-Object System.Text.UTF8Encoding -ArgumentList $false + +function Write-Info { + param([string]$Message) + [Console]::Error.WriteLine($Message) +} + +function Fail { + param([string]$Message) + [Console]::Error.WriteLine("Error: $Message") + exit 1 +} + +function Test-CommandSuccess { + param([string[]]$Command) + + try { + $Executable = $Command[0] + $CommandArgs = @() + if ($Command.Length -gt 1) { + $CommandArgs = @($Command[1..($Command.Length - 1)]) + } + + & $Executable @CommandArgs *> $null + return $LASTEXITCODE -eq 0 + } + catch { + return $false + } +} + +function Get-ComposeCommand { + if (Test-CommandSuccess @("docker", "compose", "version")) { + return @("docker", "compose") + } + + if ((Get-Command "docker-compose" -ErrorAction SilentlyContinue) -and (Test-CommandSuccess @("docker-compose", "version"))) { + return @("docker-compose") + } + + Fail "Docker Compose is not available. Install Docker Compose, then run this command again." +} + +function New-SecretKey { + $Bytes = New-Object byte[] 42 + $Generator = [System.Security.Cryptography.RandomNumberGenerator]::Create() + + try { + $Generator.GetBytes($Bytes) + } + finally { + $Generator.Dispose() + } + + return [Convert]::ToBase64String($Bytes) +} + +function Ensure-EnvFiles { + if (-not (Test-Path $DefaultEnvFile -PathType Leaf)) { + Fail "$DefaultEnvFile is missing." + } + + if (Test-Path $UserEnvFile -PathType Leaf) { + return + } + + New-Item -ItemType File -Path $UserEnvFile | Out-Null + + if ([Console]::IsInputRedirected) { + Write-Info "Created $UserEnvFile for local overrides." + return + } + + Write-Info "Created $UserEnvFile for local overrides." + $Answer = Read-Host "Do you need a custom deployment now? (Most users can press Enter to skip.) [y/N]" + + if ($Answer -match "^(y|yes)$") { + Write-Output "Edit .env with the settings you want to override, using .env.example as the full reference." + Write-Output "Run .\dify-compose.ps1 up -d again when you are ready." + exit 0 + } +} + +function Read-EnvFile { + param([string]$Path) + + $Values = [ordered]@{} + + if (-not (Test-Path $Path -PathType Leaf)) { + return $Values + } + + foreach ($Line in Get-Content -Path $Path) { + if ($Line -match "^\s*#" -or $Line -notmatch "=") { + continue + } + + $SeparatorIndex = $Line.IndexOf("=") + $Key = $Line.Substring(0, $SeparatorIndex).Trim() + $Value = $Line.Substring($SeparatorIndex + 1).Trim() + + if (($Value.StartsWith('"') -and $Value.EndsWith('"')) -or ($Value.StartsWith("'") -and $Value.EndsWith("'"))) { + $Value = $Value.Substring(1, $Value.Length - 2) + } + + if ($Key.Length -gt 0) { + $Values[$Key] = $Value + } + } + + return $Values +} + +function Set-UserEnvValue { + param( + [string]$Key, + [string]$Value + ) + + $Path = [string](Resolve-Path $UserEnvFile) + $Lines = [System.IO.File]::ReadAllLines($Path, [System.Text.Encoding]::UTF8) + $Output = New-Object System.Collections.Generic.List[string] + $Replaced = $false + + foreach ($Line in $Lines) { + if ($Line -match "^\s*#" -or $Line -notmatch "=") { + $Output.Add($Line) + continue + } + + $SeparatorIndex = $Line.IndexOf("=") + $CurrentKey = $Line.Substring(0, $SeparatorIndex).Trim() + + if ($CurrentKey -eq $Key) { + if (-not $Replaced) { + $Output.Add("$Key=$Value") + $Replaced = $true + } + continue + } + + $Output.Add($Line) + } + + if (-not $Replaced) { + $Output.Add("$Key=$Value") + } + + [System.IO.File]::WriteAllLines($Path, $Output, $Utf8NoBom) +} + +function Ensure-SecretKey { + $Values = Read-EnvFile $UserEnvFile + + if ($Values.Contains("SECRET_KEY") -and $Values["SECRET_KEY"]) { + return + } + + Set-UserEnvValue "SECRET_KEY" (New-SecretKey) + Write-Info "Generated SECRET_KEY in $UserEnvFile." +} + +function Merge-EnvValues { + $Values = [ordered]@{} + + foreach ($Entry in (Read-EnvFile $DefaultEnvFile).GetEnumerator()) { + $Values[$Entry.Key] = $Entry.Value + } + + foreach ($Entry in (Read-EnvFile $UserEnvFile).GetEnumerator()) { + $Values[$Entry.Key] = $Entry.Value + } + + return $Values +} + +function User-Overrides { + param([string]$Key) + + if (-not (Test-Path $UserEnvFile -PathType Leaf)) { + return $false + } + + return [bool](Select-String -Path $UserEnvFile -Pattern "^\s*$([regex]::Escape($Key))\s*=" -Quiet) +} + +function Metadata-DbHost { + param([string]$DbType, $Values) + + switch ($DbType) { + "mysql" { return "db_mysql" } + "postgresql" { return "db_postgres" } + "" { return "db_postgres" } + default { return $Values["DB_HOST"] } + } +} + +function Metadata-DbPort { + param([string]$DbType, $Values) + + switch ($DbType) { + "mysql" { return "3306" } + "postgresql" { return "5432" } + "" { return "5432" } + default { return $Values["DB_PORT"] } + } +} + +function Metadata-DbUser { + param([string]$DbType, $Values) + + switch ($DbType) { + "mysql" { return "root" } + "postgresql" { return "postgres" } + "" { return "postgres" } + default { return $Values["DB_USERNAME"] } + } +} + +function Write-MergedEnv { + param($Values) + + $Output = New-Object System.Collections.Generic.List[string] + + foreach ($Entry in $Values.GetEnumerator()) { + $Output.Add("$($Entry.Key)=$($Entry.Value)") + } + + [System.IO.File]::WriteAllLines($MergedEnvFile, $Output, $Utf8NoBom) +} + +function Build-MergedEnv { + $Values = Merge-EnvValues + $script:MergedEnvFile = [System.IO.Path]::GetTempFileName() + + $DbType = if ($Values.Contains("DB_TYPE")) { $Values["DB_TYPE"] } else { "postgresql" } + + if (-not (User-Overrides "DB_HOST")) { + $Values["DB_HOST"] = Metadata-DbHost $DbType $Values + } + + if (-not (User-Overrides "DB_PORT")) { + $Values["DB_PORT"] = Metadata-DbPort $DbType $Values + } + + if (-not (User-Overrides "DB_USERNAME")) { + $Values["DB_USERNAME"] = Metadata-DbUser $DbType $Values + } + + if (-not (User-Overrides "CELERY_BROKER_URL")) { + $RedisHost = if ($Values.Contains("REDIS_HOST") -and $Values["REDIS_HOST"]) { $Values["REDIS_HOST"] } else { "redis" } + $RedisPort = if ($Values.Contains("REDIS_PORT") -and $Values["REDIS_PORT"]) { $Values["REDIS_PORT"] } else { "6379" } + $RedisUsername = if ($Values.Contains("REDIS_USERNAME")) { $Values["REDIS_USERNAME"] } else { "" } + $RedisPassword = if ($Values.Contains("REDIS_PASSWORD")) { $Values["REDIS_PASSWORD"] } else { "" } + $RedisAuth = "" + + if ($RedisUsername -and $RedisPassword) { + $RedisAuth = "${RedisUsername}:${RedisPassword}@" + } + elseif ($RedisPassword) { + $RedisAuth = ":${RedisPassword}@" + } + elseif ($RedisUsername) { + $RedisAuth = "${RedisUsername}@" + } + + $Values["CELERY_BROKER_URL"] = "redis://$RedisAuth${RedisHost}:${RedisPort}/1" + } + + if (-not (User-Overrides "SANDBOX_API_KEY")) { + $CodeExecutionApiKey = if ($Values.Contains("CODE_EXECUTION_API_KEY") -and $Values["CODE_EXECUTION_API_KEY"]) { $Values["CODE_EXECUTION_API_KEY"] } else { "dify-sandbox" } + $Values["SANDBOX_API_KEY"] = $CodeExecutionApiKey + } + + if (-not (User-Overrides "WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS")) { + $WeaviateApiKey = if ($Values.Contains("WEAVIATE_API_KEY") -and $Values["WEAVIATE_API_KEY"]) { $Values["WEAVIATE_API_KEY"] } else { "WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih" } + $Values["WEAVIATE_AUTHENTICATION_APIKEY_ALLOWED_KEYS"] = $WeaviateApiKey + } + + Write-MergedEnv $Values +} + +$ComposeCommand = Get-ComposeCommand + +try { + Ensure-EnvFiles + Ensure-SecretKey + Build-MergedEnv + + $ComposeArgs = @($args) + if ($ComposeArgs.Count -eq 0) { + $ComposeArgs = @("up", "-d") + } + + $ComposeCommandArgs = @() + if ($ComposeCommand.Length -gt 1) { + $ComposeCommandArgs = @($ComposeCommand[1..($ComposeCommand.Length - 1)]) + } + + $ComposeExecutable = $ComposeCommand[0] + & $ComposeExecutable @ComposeCommandArgs --env-file $MergedEnvFile @ComposeArgs + exit $LASTEXITCODE +} +finally { + if ($MergedEnvFile -and (Test-Path $MergedEnvFile -PathType Leaf)) { + Remove-Item -Force $MergedEnvFile + } +} diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 87fa01f671..b2df61ebb2 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -170,8 +170,8 @@ services: ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} - TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} - INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} + INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} @@ -402,8 +402,8 @@ services: - ./certbot/update-cert.template.txt:/update-cert.template.txt - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh environment: - - CERTBOT_EMAIL=${CERTBOT_EMAIL} - - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} + - CERTBOT_EMAIL=${CERTBOT_EMAIL:-} + - CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} entrypoint: ["/docker-entrypoint.sh"] command: ["tail", "-f", "/dev/null"] diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index a72136049d..6dcab4a9fc 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -441,10 +441,10 @@ x-shared-env: &shared-api-worker-env NOTION_CLIENT_SECRET: ${NOTION_CLIENT_SECRET:-} NOTION_CLIENT_ID: ${NOTION_CLIENT_ID:-} NOTION_INTERNAL_SECRET: ${NOTION_INTERNAL_SECRET:-} - MAIL_TYPE: ${MAIL_TYPE:-resend} + MAIL_TYPE: ${MAIL_TYPE:-} MAIL_DEFAULT_SEND_FROM: ${MAIL_DEFAULT_SEND_FROM:-} RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} - RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} + RESEND_API_KEY: ${RESEND_API_KEY:-} SMTP_SERVER: ${SMTP_SERVER:-} SMTP_PORT: ${SMTP_PORT:-465} SMTP_USERNAME: ${SMTP_USERNAME:-} @@ -586,8 +586,8 @@ x-shared-env: &shared-api-worker-env NGINX_PROXY_READ_TIMEOUT: ${NGINX_PROXY_READ_TIMEOUT:-3600s} NGINX_PROXY_SEND_TIMEOUT: ${NGINX_PROXY_SEND_TIMEOUT:-3600s} NGINX_ENABLE_CERTBOT_CHALLENGE: ${NGINX_ENABLE_CERTBOT_CHALLENGE:-false} - CERTBOT_EMAIL: ${CERTBOT_EMAIL:-your_email@example.com} - CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-your_domain.com} + CERTBOT_EMAIL: ${CERTBOT_EMAIL:-} + CERTBOT_DOMAIN: ${CERTBOT_DOMAIN:-} CERTBOT_OPTIONS: ${CERTBOT_OPTIONS:-} SSRF_HTTP_PORT: ${SSRF_HTTP_PORT:-3128} SSRF_COREDUMP_DIR: ${SSRF_COREDUMP_DIR:-/var/spool/squid} @@ -894,8 +894,8 @@ services: ALLOW_UNSAFE_DATA_SCHEME: ${ALLOW_UNSAFE_DATA_SCHEME:-false} MARKETPLACE_API_URL: ${MARKETPLACE_API_URL:-https://marketplace.dify.ai} MARKETPLACE_URL: ${MARKETPLACE_URL:-https://marketplace.dify.ai} - TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-} - INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-} + TOP_K_MAX_VALUE: ${TOP_K_MAX_VALUE:-10} + INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} LOOP_NODE_MAX_COUNT: ${LOOP_NODE_MAX_COUNT:-100} MAX_TOOLS_NUM: ${MAX_TOOLS_NUM:-10} MAX_PARALLEL_LIMIT: ${MAX_PARALLEL_LIMIT:-10} @@ -1126,8 +1126,8 @@ services: - ./certbot/update-cert.template.txt:/update-cert.template.txt - ./certbot/docker-entrypoint.sh:/docker-entrypoint.sh environment: - - CERTBOT_EMAIL=${CERTBOT_EMAIL} - - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} + - CERTBOT_EMAIL=${CERTBOT_EMAIL:-} + - CERTBOT_DOMAIN=${CERTBOT_DOMAIN:-} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} entrypoint: ["/docker-entrypoint.sh"] command: ["tail", "-f", "/dev/null"] diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 2147bb95e8..b4876dcf45 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -4194,11 +4194,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-value-method.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/index.tsx": { "no-restricted-imports": { "count": 1 diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index fe4c10329e..4a4742adcf 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -4156,8 +4156,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesResponse export type DeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdData = { body?: never path: { - app_id: string variable_id: string + app_id: string } query?: never url: '/apps/{app_id}/workflows/draft/variables/{variable_id}' @@ -4210,8 +4210,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesByVariableIdResponse export type PatchAppsByAppIdWorkflowsDraftVariablesByVariableIdData = { body: WorkflowDraftVariableUpdatePayload path: { - app_id: string variable_id: string + app_id: string } query?: never url: '/apps/{app_id}/workflows/draft/variables/{variable_id}' diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index dcaeaed246..9798d22cc0 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -2980,8 +2980,8 @@ export const zGetAppsByAppIdWorkflowsDraftVariablesQuery = z.object({ export const zGetAppsByAppIdWorkflowsDraftVariablesResponse = zWorkflowDraftVariableListWithoutValue export const zDeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({ - app_id: z.string(), variable_id: z.string(), + app_id: z.string(), }) /** @@ -3006,8 +3006,8 @@ export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdBody = zWorkflowDraftVariableUpdatePayload export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({ - app_id: z.string(), variable_id: z.string(), + app_id: z.string(), }) /** diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 89a68593b7..61d380d686 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -255,6 +255,7 @@ export type ProcessRule = { } export type RetrievalModel = { + metadata_filtering_conditions?: MetadataFilteringCondition reranking_enable: boolean reranking_mode?: string | null reranking_model?: RerankingModel @@ -312,6 +313,11 @@ export type Rule = { subchunk_segmentation?: Segmentation } +export type MetadataFilteringCondition = { + conditions?: Array | null + logical_operator?: 'and' | 'or' | null +} + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null @@ -405,6 +411,30 @@ export type Segmentation = { separator?: string } +export type Condition = { + comparison_operator: + | 'contains' + | 'not contains' + | 'start with' + | 'end with' + | 'is' + | 'is not' + | 'empty' + | 'not empty' + | 'in' + | 'not in' + | '=' + | '≠' + | '>' + | '<' + | '≥' + | '≤' + | 'before' + | 'after' + name: string + value?: unknown +} + export type WeightKeywordSetting = { keyword_weight: number } @@ -1174,8 +1204,8 @@ export type PatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse export type DeleteDatasetsByDatasetIdDocumentsByDocumentIdData = { body?: never path: { - dataset_id: string document_id: string + dataset_id: string } query?: never url: '/datasets/{dataset_id}/documents/{document_id}' diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index 2ac2cbfd1f..76491c52a0 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -392,6 +392,46 @@ export const zProcessRule = z.object({ rules: zRule.optional(), }) +/** + * Condition + * + * Condition detail + */ +export const zCondition = z.object({ + comparison_operator: z.enum([ + 'contains', + 'not contains', + 'start with', + 'end with', + 'is', + 'is not', + 'empty', + 'not empty', + 'in', + 'not in', + '=', + '≠', + '>', + '<', + '≥', + '≤', + 'before', + 'after', + ]), + name: z.string(), + value: z.unknown().optional(), +}) + +/** + * MetadataFilteringCondition + * + * Metadata Filtering Condition. + */ +export const zMetadataFilteringCondition = z.object({ + conditions: z.array(zCondition).nullish(), + logical_operator: z.enum(['and', 'or']).nullish().default('and'), +}) + /** * WeightKeywordSetting */ @@ -421,6 +461,7 @@ export const zWeightModel = z.object({ * RetrievalModel */ export const zRetrievalModel = z.object({ + metadata_filtering_conditions: zMetadataFilteringCondition.optional(), reranking_enable: z.boolean(), reranking_mode: z.string().nullish(), reranking_model: zRerankingModel.optional(), @@ -925,8 +966,8 @@ export const zPatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse = z.r ) export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdPath = z.object({ - dataset_id: z.string(), document_id: z.string(), + dataset_id: z.string(), }) /** diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index f491c1e3f9..e3791e295c 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -325,8 +325,37 @@ export type WorkflowRunResponse = { workflow_id: string } +export type Condition = { + comparison_operator: + | 'contains' + | 'not contains' + | 'start with' + | 'end with' + | 'is' + | 'is not' + | 'empty' + | 'not empty' + | 'in' + | 'not in' + | '=' + | '≠' + | '>' + | '<' + | '≥' + | '≤' + | 'before' + | 'after' + name: string + value?: unknown +} + export type DatasetPermissionEnum = 'only_me' | 'all_team_members' | 'partial_members' +export type MetadataFilteringCondition = { + conditions?: Array | null + logical_operator?: 'and' | 'or' | null +} + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null @@ -339,6 +368,7 @@ export type RetrievalMethod | 'keyword_search' export type RetrievalModel = { + metadata_filtering_conditions?: MetadataFilteringCondition reranking_enable: boolean reranking_mode?: string | null reranking_model?: RerankingModel @@ -1833,8 +1863,8 @@ export type GetDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdData = body?: never path: { segment_id: string - dataset_id: string document_id: string + dataset_id: string } query?: never url: '/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 2c2400c0cb..6feacbdead 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -326,11 +326,51 @@ export const zWorkflowRunResponse = z.object({ workflow_id: z.string(), }) +/** + * Condition + * + * Condition detail + */ +export const zCondition = z.object({ + comparison_operator: z.enum([ + 'contains', + 'not contains', + 'start with', + 'end with', + 'is', + 'is not', + 'empty', + 'not empty', + 'in', + 'not in', + '=', + '≠', + '>', + '<', + '≥', + '≤', + 'before', + 'after', + ]), + name: z.string(), + value: z.unknown().optional(), +}) + /** * DatasetPermissionEnum */ export const zDatasetPermissionEnum = z.enum(['only_me', 'all_team_members', 'partial_members']) +/** + * MetadataFilteringCondition + * + * Metadata Filtering Condition. + */ +export const zMetadataFilteringCondition = z.object({ + conditions: z.array(zCondition).nullish(), + logical_operator: z.enum(['and', 'or']).nullish().default('and'), +}) + /** * RerankingModel */ @@ -378,6 +418,7 @@ export const zWeightModel = z.object({ * RetrievalModel */ export const zRetrievalModel = z.object({ + metadata_filtering_conditions: zMetadataFilteringCondition.optional(), reranking_enable: z.boolean(), reranking_mode: z.string().nullish(), reranking_model: zRerankingModel.optional(), @@ -1082,8 +1123,8 @@ export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdR export const zGetDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdPath = z.object({ segment_id: z.string(), - dataset_id: z.string(), document_id: z.string(), + dataset_id: z.string(), }) /** diff --git a/packages/dev-proxy/README.md b/packages/dev-proxy/README.md new file mode 100644 index 0000000000..6b9d7298c4 --- /dev/null +++ b/packages/dev-proxy/README.md @@ -0,0 +1,196 @@ +# @langgenius/dev-proxy + +Generic Hono-based development proxy for frontend projects. The package does not ship any product-specific routes, cookie names, or environment variable conventions. Every proxied path and upstream target is declared in a local config file. + +## Installation + +```bash +pnpm add -D @langgenius/dev-proxy +``` + +Add a script in your frontend project: + +```json +{ + "scripts": { + "dev:proxy": "dev-proxy --config ./dev-proxy.config.ts --env-file ./.env" + } +} +``` + +Run it with: + +```bash +pnpm dev:proxy +``` + +## CLI + +```bash +dev-proxy --config ./dev-proxy.config.ts +``` + +Supported options: + +- `--config`, `-c`: config file path. Defaults to `dev-proxy.config.ts`. +- `--env-file`: load environment variables before evaluating the config file. +- `--host`: override `server.host` from config. +- `--port`: override `server.port` from config. +- `--help`, `-h`: print help. + +`--target` is not supported. Put targets in the config file so routes and upstreams stay explicit. + +## Config Shape + +```ts +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +export default defineDevProxyConfig({ + server: { + host: '127.0.0.1', + port: 5001, + }, + routes: [ + { + paths: '/api', + target: 'https://example.com', + }, + ], + cors: { + allowedOrigins: 'local', + }, +}) +``` + +Config files can be `.ts`, `.mts`, `.js`, or `.mjs`. + +`routes` are matched in declaration order. The first matching route wins. Each configured path matches both the exact path and all child paths, so `paths: '/api'` matches `/api`, `/api/apps`, and `/api/apps/123`. + +By default, credentialed CORS is allowed for local development origins such as `localhost`, `127.0.0.1`, and `::1`. To restrict it to specific origins: + +``` +cors: { + allowedOrigins: ['http://localhost:3000'], +} +``` + +## Scenario 1: Proxy One Local Route Group To An Online Backend + +Use this when a local frontend should call an online backend through one proxy server. For example, the frontend calls `http://127.0.0.1:5001/api/apps`, and the proxy forwards it to `https://cloud.example.com/api/apps`. + +```ts +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +const target = process.env.DEV_PROXY_TARGET || 'https://cloud.example.com' + +export default defineDevProxyConfig({ + server: { + host: process.env.DEV_PROXY_HOST || '127.0.0.1', + port: Number(process.env.DEV_PROXY_PORT || 5001), + }, + routes: [ + { + paths: '/api', + target, + }, + ], +}) +``` + +Optional `.env`: + +```env +DEV_PROXY_TARGET=https://cloud.example.com +DEV_PROXY_HOST=127.0.0.1 +DEV_PROXY_PORT=5001 +``` + +Command: + +```bash +dev-proxy --config ./dev-proxy.config.ts --env-file ./.env +``` + +## Scenario 2: Proxy Two Route Groups To Two Local Backends + +Use this when one frontend needs to talk to two different local services. For example: + +- `/console/api/*` goes to a local console backend at `http://127.0.0.1:5001` +- `/api/*` goes to a local public API backend at `http://127.0.0.1:5002` + +```ts +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +const consoleApiTarget = process.env.DEV_PROXY_CONSOLE_API_TARGET || 'http://127.0.0.1:5001' +const publicApiTarget = process.env.DEV_PROXY_PUBLIC_API_TARGET || 'http://127.0.0.1:5002' + +export default defineDevProxyConfig({ + server: { + host: process.env.DEV_PROXY_HOST || '127.0.0.1', + port: Number(process.env.DEV_PROXY_PORT || 8082), + }, + routes: [ + { + paths: '/console/api', + target: consoleApiTarget, + }, + { + paths: '/api', + target: publicApiTarget, + }, + ], +}) +``` + +Optional `.env`: + +```env +DEV_PROXY_CONSOLE_API_TARGET=http://127.0.0.1:5001 +DEV_PROXY_PUBLIC_API_TARGET=http://127.0.0.1:5002 +DEV_PROXY_HOST=127.0.0.1 +DEV_PROXY_PORT=8082 +``` + +When two route groups overlap, put the more specific one first: + +```ts +routes: [ + { paths: '/api/enterprise', target: 'http://127.0.0.1:5003' }, + { paths: '/api', target: 'http://127.0.0.1:5002' }, +] +``` + +## Cookie Rewrite + +Cookie rewriting is opt-in and config-driven. The package does not know any application cookie names. + +Use `cookieRewrite` when an upstream uses secure cookie prefixes such as `__Host-` or `__Secure-`, but local development needs cookies to work over `http://localhost`. + +```ts +import type { CookieRewriteOptions } from '@langgenius/dev-proxy' +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +const cookieRewrite: CookieRewriteOptions = { + hostPrefixCookies: ['access_token', 'refresh_token', /^passport-/], +} + +export default defineDevProxyConfig({ + routes: [ + { + paths: '/api', + target: 'https://cloud.example.com', + cookieRewrite, + }, + ], +}) +``` + +Set `cookieRewrite: false` to disable cookie rewriting for a route. + +## Behavior + +- The proxy preserves the matched path prefix when forwarding requests. +- Request bodies are forwarded as streams. +- Hop-by-hop headers are removed before forwarding. +- Local credentialed CORS and preflight requests are handled by the proxy. +- Route matching is explicit and order-sensitive. diff --git a/packages/dev-proxy/bin/dev-proxy.js b/packages/dev-proxy/bin/dev-proxy.js new file mode 100755 index 0000000000..02e37f3525 --- /dev/null +++ b/packages/dev-proxy/bin/dev-proxy.js @@ -0,0 +1,3 @@ +#!/usr/bin/env node + +import '../dist/cli.mjs' diff --git a/packages/dev-proxy/package.json b/packages/dev-proxy/package.json new file mode 100644 index 0000000000..d5524290eb --- /dev/null +++ b/packages/dev-proxy/package.json @@ -0,0 +1,43 @@ +{ + "name": "@langgenius/dev-proxy", + "type": "module", + "version": "0.0.5", + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs" + } + }, + "types": "./dist/index.d.mts", + "bin": { + "dev-proxy": "./bin/dev-proxy.js" + }, + "files": [ + "bin", + "dist", + "src" + ], + "engines": { + "node": "^22.22.1" + }, + "scripts": { + "build": "vp pack", + "prepare": "pnpm run build", + "test": "vp test", + "type-check": "tsgo", + "prepublish": "pnpm run build" + }, + "dependencies": { + "@hono/node-server": "catalog:", + "c12": "catalog:", + "hono": "catalog:" + }, + "devDependencies": { + "@dify/tsconfig": "workspace:*", + "@types/node": "catalog:", + "@typescript/native-preview": "catalog:", + "vite": "catalog:", + "vite-plus": "catalog:", + "vitest": "catalog:" + } +} diff --git a/packages/dev-proxy/src/cli.spec.ts b/packages/dev-proxy/src/cli.spec.ts new file mode 100644 index 0000000000..e8a87a0588 --- /dev/null +++ b/packages/dev-proxy/src/cli.spec.ts @@ -0,0 +1,158 @@ +/** + * @vitest-environment node + */ +import type { ChildProcessByStdio } from 'node:child_process' +import type { Readable } from 'node:stream' +import { spawn } from 'node:child_process' +import { once } from 'node:events' +import fs from 'node:fs/promises' +import net from 'node:net' +import os from 'node:os' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { afterEach, describe, expect, it } from 'vitest' + +const tempDirs: string[] = [] +type DevProxyCliProcess = ChildProcessByStdio + +const childProcesses: DevProxyCliProcess[] = [] +const binPath = fileURLToPath(new URL('../bin/dev-proxy.js', import.meta.url)) + +const createTempDir = async () => { + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'dev-proxy-cli-test-')) + tempDirs.push(tempDir) + return tempDir +} + +const getFreePort = async () => { + const server = net.createServer() + await new Promise((resolve, reject) => { + server.once('error', reject) + server.listen(0, '127.0.0.1', resolve) + }) + + const address = server.address() + if (!address || typeof address === 'string') + throw new Error('Failed to allocate a test port.') + + const { port } = address + await new Promise((resolve, reject) => { + server.close((error) => { + if (error) + reject(error) + else + resolve() + }) + }) + + return port +} + +const waitForOutput = ( + child: DevProxyCliProcess, + output: () => string, + expectedOutput: string, +) => new Promise((resolve, reject) => { + let timeout: ReturnType + + function cleanup() { + clearTimeout(timeout) + child.stdout.off('data', onData) + child.stderr.off('data', onData) + child.off('exit', onExit) + } + + function onData() { + if (!output().includes(expectedOutput)) + return + + cleanup() + resolve() + } + + function onExit(code: number | null, signal: NodeJS.Signals | null) { + cleanup() + reject(new Error(`dev-proxy exited before writing "${expectedOutput}" with code ${code} and signal ${signal}. Output:\n${output()}`)) + } + + timeout = setTimeout(() => { + cleanup() + reject(new Error(`Timed out waiting for "${expectedOutput}". Output:\n${output()}`)) + }, 3000) + + child.stdout.on('data', onData) + child.stderr.on('data', onData) + child.once('exit', onExit) + onData() +}) + +const spawnCli = (args: readonly string[], cwd: string) => { + const child = spawn(process.execPath, [binPath, ...args], { + cwd, + env: { + ...process.env, + FORCE_COLOR: '0', + }, + stdio: ['ignore', 'pipe', 'pipe'], + }) + childProcesses.push(child) + return child +} + +const stopChildProcess = async (child: DevProxyCliProcess) => { + if (child.exitCode !== null || child.signalCode !== null) + return + + child.kill('SIGTERM') + await once(child, 'exit') +} + +describe('dev proxy CLI', () => { + afterEach(async () => { + await Promise.all(childProcesses.splice(0).map(stopChildProcess)) + await Promise.all(tempDirs.splice(0).map(tempDir => fs.rm(tempDir, { + force: true, + recursive: true, + }))) + }) + + // Scenario: help output should still be a normal short-lived command. + it('should print help and exit', async () => { + // Arrange + const tempDir = await createTempDir() + const child = spawnCli(['--help'], tempDir) + + // Act + const [code] = await once(child, 'exit') + + // Assert + expect(code).toBe(0) + }) + + // Scenario: successful server startup should keep the CLI process alive. + it('should keep running after starting the proxy server', async () => { + // Arrange + const tempDir = await createTempDir() + const port = await getFreePort() + await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), ` + export default { + routes: [{ paths: '/api', target: 'https://api.example.com' }], + } + `) + + let output = '' + const child = spawnCli(['--config', './dev-proxy.config.ts', '--host', '127.0.0.1', '--port', String(port)], tempDir) + child.stdout.on('data', chunk => output += chunk.toString()) + child.stderr.on('data', chunk => output += chunk.toString()) + + // Act + await waitForOutput(child, () => output, `[dev-proxy] listening on http://127.0.0.1:${port}`) + await new Promise(resolve => setTimeout(resolve, 100)) + const response = await fetch(`http://127.0.0.1:${port}/not-proxied`) + + // Assert + expect(child.exitCode).toBeNull() + expect(child.signalCode).toBeNull() + expect(response.status).toBe(404) + }) +}) diff --git a/packages/dev-proxy/src/cli.ts b/packages/dev-proxy/src/cli.ts new file mode 100644 index 0000000000..05234cb359 --- /dev/null +++ b/packages/dev-proxy/src/cli.ts @@ -0,0 +1,56 @@ +import process from 'node:process' +import { serve } from '@hono/node-server' +import { loadDevProxyConfig, parseDevProxyCliArgs, resolveDevProxyServerOptions } from './config' +import { createDevProxyApp } from './server' + +function printUsage() { + console.log(`Usage: + dev-proxy --config [options] + +Options: + --config, -c Path to a dev proxy config file. Defaults to dev-proxy.config.ts. + --env-file Load environment variables before evaluating the config file. + --host Override the configured host. + --port Override the configured port. + --help, -h Show this help message.`) +} + +async function flushStandardStreams() { + await Promise.all([ + new Promise(resolve => process.stdout.write('', () => resolve())), + new Promise(resolve => process.stderr.write('', () => resolve())), + ]) +} + +async function main() { + const cliOptions = parseDevProxyCliArgs(process.argv.slice(2)) + + if (cliOptions.help) { + printUsage() + return + } + + const config = await loadDevProxyConfig(cliOptions.config, process.cwd(), { + envFile: cliOptions.envFile, + }) + const { host, port } = resolveDevProxyServerOptions(config.server, cliOptions) + const app = createDevProxyApp(config) + + serve({ + fetch: app.fetch, + hostname: host, + port, + }) + + console.log(`[dev-proxy] listening on http://${host}:${port}`) +} + +try { + await main() + await flushStandardStreams() +} +catch (error) { + console.error(error instanceof Error ? error.message : error) + await flushStandardStreams() + process.exit(1) +} diff --git a/packages/dev-proxy/src/config.spec.ts b/packages/dev-proxy/src/config.spec.ts new file mode 100644 index 0000000000..6f681bcbae --- /dev/null +++ b/packages/dev-proxy/src/config.spec.ts @@ -0,0 +1,145 @@ +/** + * @vitest-environment node + */ +import fs from 'node:fs/promises' +import os from 'node:os' +import path from 'node:path' +import { afterEach, describe, expect, it } from 'vitest' +import { loadDevProxyConfig, parseDevProxyCliArgs, resolveDevProxyServerOptions } from './config' + +const tempDirs: string[] = [] + +const createTempDir = async () => { + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'dev-proxy-test-')) + tempDirs.push(tempDir) + return tempDir +} + +describe('dev proxy config', () => { + afterEach(async () => { + delete process.env.DEV_PROXY_TEST_PORT + delete process.env.DEV_PROXY_TEST_TARGET + + await Promise.all(tempDirs.splice(0).map(tempDir => fs.rm(tempDir, { + force: true, + recursive: true, + }))) + }) + + // Scenario: CLI options should support both inline and separated values. + it('should parse proxy CLI options', () => { + // Act + const options = parseDevProxyCliArgs([ + '--config=./dev-proxy.config.ts', + '--env-file', + './.env.proxy', + '--host', + '0.0.0.0', + '--port', + '8083', + ]) + + // Assert + expect(options).toEqual({ + config: './dev-proxy.config.ts', + envFile: './.env.proxy', + host: '0.0.0.0', + port: '8083', + }) + }) + + // Scenario: removed target shortcuts should fail instead of silently doing the wrong thing. + it('should reject unsupported target shortcuts', () => { + // Assert + expect(() => parseDevProxyCliArgs(['--target', 'enterprise'])).toThrow('Unsupported dev proxy option') + }) + + // Scenario: package manager argument separators should not be treated as proxy options. + it('should ignore package manager argument separators', () => { + // Act + const options = parseDevProxyCliArgs(['--config', './dev-proxy.config.ts', '--', '--help']) + + // Assert + expect(options).toEqual({ + config: './dev-proxy.config.ts', + help: true, + }) + }) + + // Scenario: CLI host and port should override config defaults. + it('should resolve server options with CLI overrides', () => { + // Act + const options = resolveDevProxyServerOptions({ + host: '127.0.0.1', + port: 5001, + }, { + host: '0.0.0.0', + port: '9002', + }) + + // Assert + expect(options).toEqual({ + host: '0.0.0.0', + port: 9002, + }) + }) + + // Scenario: TS config files should load through c12. + it('should load a TypeScript config file', async () => { + // Arrange + const tempDir = await createTempDir() + await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), ` + export default { + server: { host: '127.0.0.1', port: 7777 }, + routes: [{ paths: ['/api', '/files'], target: 'https://api.example.com' }], + } + `) + + // Act + const config = await loadDevProxyConfig('dev-proxy.config.ts', tempDir) + + // Assert + expect(config.server).toEqual({ + host: '127.0.0.1', + port: 7777, + }) + expect(config.routes).toEqual([ + { + paths: ['/api', '/files'], + target: 'https://api.example.com', + }, + ]) + }) + + // Scenario: env files should be loaded before the TypeScript config is evaluated. + it('should load a TypeScript config file with env file values', async () => { + // Arrange + const tempDir = await createTempDir() + await fs.writeFile(path.join(tempDir, '.env.proxy'), [ + 'DEV_PROXY_TEST_PORT=7788', + 'DEV_PROXY_TEST_TARGET=https://env.example.com', + ].join('\n')) + await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), ` + export default { + server: { port: Number(process.env.DEV_PROXY_TEST_PORT) }, + routes: [{ paths: '/api', target: process.env.DEV_PROXY_TEST_TARGET }], + } + `) + + // Act + const config = await loadDevProxyConfig('dev-proxy.config.ts', tempDir, { + envFile: '.env.proxy', + }) + + // Assert + expect(config.server).toEqual({ + port: 7788, + }) + expect(config.routes).toEqual([ + { + paths: '/api', + target: 'https://env.example.com', + }, + ]) + }) +}) diff --git a/packages/dev-proxy/src/config.ts b/packages/dev-proxy/src/config.ts new file mode 100644 index 0000000000..b23cb0a152 --- /dev/null +++ b/packages/dev-proxy/src/config.ts @@ -0,0 +1,129 @@ +import type { DotenvOptions } from 'c12' +import type { DevProxyCliOptions, DevProxyConfig, DevProxyConfigLoadOptions, DevProxyServerConfig, ResolvedDevProxyServerOptions } from './types' +import path from 'node:path' +import { loadConfig } from 'c12' + +const DEFAULT_CONFIG_FILE = 'dev-proxy.config.ts' +const DEFAULT_PROXY_HOST = '127.0.0.1' +const DEFAULT_PROXY_PORT = 5001 + +const OPTION_NAME_TO_KEY = { + '--config': 'config', + '-c': 'config', + '--env-file': 'envFile', + '--host': 'host', + '--port': 'port', +} as const + +type OptionName = keyof typeof OPTION_NAME_TO_KEY + +const isOptionName = (value: string): value is OptionName => value in OPTION_NAME_TO_KEY + +const requireOptionValue = (name: string, value?: string) => { + if (!value || value.startsWith('-')) + throw new Error(`Missing value for ${name}.`) + + return value +} + +export const parseDevProxyCliArgs = (argv: readonly string[]): DevProxyCliOptions => { + const options: DevProxyCliOptions = {} + + for (let index = 0; index < argv.length; index += 1) { + const arg = argv[index]! + + if (arg === '--') + continue + + if (arg === '--help' || arg === '-h') { + options.help = true + continue + } + + const [rawName, inlineValue] = arg.split('=', 2) + const name = rawName ?? '' + + if (!name.startsWith('-')) + continue + + if (!isOptionName(name)) + throw new Error(`Unsupported dev proxy option "${name}".`) + + const key = OPTION_NAME_TO_KEY[name] + options[key] = inlineValue ?? requireOptionValue(name, argv[index + 1]) + + if (inlineValue === undefined) + index += 1 + } + + return options +} + +const resolvePort = (rawPort: string | number) => { + const port = Number(rawPort) + if (!Number.isInteger(port) || port < 1 || port > 65535) + throw new Error(`Invalid proxy port "${rawPort}". Expected an integer between 1 and 65535.`) + + return port +} + +export const resolveDevProxyServerOptions = ( + serverConfig: DevProxyServerConfig = {}, + cliOptions: DevProxyCliOptions = {}, +): ResolvedDevProxyServerOptions => { + const configuredPort = cliOptions.port ?? serverConfig.port ?? DEFAULT_PROXY_PORT + + return { + host: cliOptions.host || serverConfig.host || DEFAULT_PROXY_HOST, + port: resolvePort(configuredPort), + } +} + +const isRecord = (value: unknown): value is Record => + typeof value === 'object' && value !== null + +export function assertDevProxyConfig(config: unknown): asserts config is DevProxyConfig { + if (!isRecord(config)) + throw new Error('Dev proxy config must export an object.') + + if (!Array.isArray(config.routes)) + throw new Error('Dev proxy config must include a routes array.') +} + +const resolveDotenvOptions = ( + envFile: DevProxyConfigLoadOptions['envFile'], + cwd: string, +): DotenvOptions | false => { + if (!envFile) + return false + + const resolvedEnvFilePath = path.resolve(cwd, envFile) + return { + cwd: path.dirname(resolvedEnvFilePath), + fileName: path.basename(resolvedEnvFilePath), + interpolate: true, + } +} + +export const loadDevProxyConfig = async ( + configPath = DEFAULT_CONFIG_FILE, + cwd = process.cwd(), + options: DevProxyConfigLoadOptions = {}, +): Promise => { + const resolvedConfigPath = path.resolve(cwd, configPath) + const parsedPath = path.parse(resolvedConfigPath) + const { config: loadedConfig } = await loadConfig({ + configFile: parsedPath.name, + cwd: parsedPath.dir, + dotenv: resolveDotenvOptions(options.envFile, cwd), + envName: false, + globalRc: false, + packageJson: false, + rcFile: false, + }) + + assertDevProxyConfig(loadedConfig) + return loadedConfig +} + +export const defineDevProxyConfig = (config: DevProxyConfig) => config diff --git a/packages/dev-proxy/src/cookies.spec.ts b/packages/dev-proxy/src/cookies.spec.ts new file mode 100644 index 0000000000..4a1b614eeb --- /dev/null +++ b/packages/dev-proxy/src/cookies.spec.ts @@ -0,0 +1,44 @@ +/** + * @vitest-environment node + */ +import { describe, expect, it } from 'vitest' +import { rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies' + +describe('dev proxy cookies', () => { + // Scenario: cookie names should only receive secure host prefixes when configured. + it('should rewrite configured cookie names for HTTPS upstream requests', () => { + // Act + const cookieHeader = rewriteCookieHeaderForUpstream('access_token=abc; theme=dark; passport-app=def', { + hostPrefixCookies: ['access_token', /^passport-/], + useHostPrefix: true, + }) + + // Assert + expect(cookieHeader).toBe('__Host-access_token=abc; theme=dark; __Host-passport-app=def') + }) + + // Scenario: HTTP upstreams should keep local cookie names even when rewrite config exists. + it('should keep local cookie names for HTTP upstream requests', () => { + // Act + const cookieHeader = rewriteCookieHeaderForUpstream('access_token=abc; refresh_token=def', { + hostPrefixCookies: ['access_token', 'refresh_token'], + useHostPrefix: false, + }) + + // Assert + expect(cookieHeader).toBe('access_token=abc; refresh_token=def') + }) + + // Scenario: upstream set-cookie headers should be converted into localhost-safe cookies. + it('should rewrite upstream set-cookie headers for local development', () => { + // Act + const cookies = rewriteSetCookieHeadersForLocal([ + '__Host-access_token=abc; Path=/console/api; Domain=cloud.example.com; Secure; SameSite=None; Partitioned', + ]) + + // Assert + expect(cookies).toEqual([ + 'access_token=abc; Path=/; SameSite=Lax', + ]) + }) +}) diff --git a/web/plugins/dev-proxy/cookies.ts b/packages/dev-proxy/src/cookies.ts similarity index 61% rename from web/plugins/dev-proxy/cookies.ts rename to packages/dev-proxy/src/cookies.ts index ad087d1549..61fdb6abd4 100644 --- a/web/plugins/dev-proxy/cookies.ts +++ b/packages/dev-proxy/src/cookies.ts @@ -1,4 +1,4 @@ -const DEFAULT_PROXY_TARGET = 'https://cloud.dify.ai' +import type { CookieRewriteOptions } from './types' const SECURE_COOKIE_PREFIX_PATTERN = /^__(Host|Secure)-/ const SAME_SITE_NONE_PATTERN = /^samesite=none$/i @@ -7,38 +7,37 @@ const COOKIE_DOMAIN_PATTERN = /^domain=/i const COOKIE_SECURE_PATTERN = /^secure$/i const COOKIE_PARTITIONED_PATTERN = /^partitioned$/i -const HOST_PREFIX_COOKIE_NAMES = new Set([ - 'access_token', - 'csrf_token', - 'refresh_token', - 'webapp_access_token', -]) +const stripSecureCookiePrefix = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '') -const isPassportCookie = (cookieName: string) => cookieName.startsWith('passport-') +const matchesCookieName = (cookieName: string, matcher: string | RegExp) => + typeof matcher === 'string' + ? matcher === cookieName + : matcher.test(cookieName) -const shouldUseHostPrefix = (cookieName: string) => { - const normalizedCookieName = cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '') - return HOST_PREFIX_COOKIE_NAMES.has(normalizedCookieName) || isPassportCookie(normalizedCookieName) +const shouldUseHostPrefix = (cookieName: string, options: CookieRewriteOptions) => { + const normalizedCookieName = stripSecureCookiePrefix(cookieName) + + return options.hostPrefixCookies?.some(matcher => matchesCookieName(normalizedCookieName, matcher)) || false } -const toUpstreamCookieName = (cookieName: string) => { +const toUpstreamCookieName = (cookieName: string, options: CookieRewriteOptions) => { if (cookieName.startsWith('__Host-')) return cookieName if (cookieName.startsWith('__Secure-')) - return `__Host-${cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')}` + return `__Host-${stripSecureCookiePrefix(cookieName)}` - if (!shouldUseHostPrefix(cookieName)) + if (!shouldUseHostPrefix(cookieName, options)) return cookieName return `__Host-${cookieName}` } -const toLocalCookieName = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '') +export const toLocalCookieName = (cookieName: string) => stripSecureCookiePrefix(cookieName) export const rewriteCookieHeaderForUpstream = ( - cookieHeader?: string, - options: { useHostPrefix?: boolean } = {}, + cookieHeader: string | undefined, + options: CookieRewriteOptions & { useHostPrefix?: boolean }, ) => { if (!cookieHeader) return cookieHeader @@ -55,7 +54,11 @@ export const rewriteCookieHeaderForUpstream = ( const cookieName = cookie.slice(0, separatorIndex).trim() const cookieValue = cookie.slice(separatorIndex + 1) - return `${useHostPrefix ? toUpstreamCookieName(cookieName) : cookieName}=${cookieValue}` + const upstreamCookieName = useHostPrefix + ? toUpstreamCookieName(cookieName, options) + : cookieName + + return `${upstreamCookieName}=${cookieValue}` }) .join('; ') } @@ -89,15 +92,5 @@ const rewriteSetCookieValueForLocal = (setCookieValue: string) => { return [`${toLocalCookieName(cookieName)}=${cookieValue}`, ...rewrittenAttributes].join('; ') } -export const rewriteSetCookieHeadersForLocal = (setCookieHeaders?: string | string[]): string[] | undefined => { - if (!setCookieHeaders) - return undefined - - const normalizedHeaders = Array.isArray(setCookieHeaders) - ? setCookieHeaders - : [setCookieHeaders] - - return normalizedHeaders.map(rewriteSetCookieValueForLocal) -} - -export { DEFAULT_PROXY_TARGET } +export const rewriteSetCookieHeadersForLocal = (setCookieHeaders: readonly string[]) => + setCookieHeaders.map(rewriteSetCookieValueForLocal) diff --git a/packages/dev-proxy/src/index.ts b/packages/dev-proxy/src/index.ts new file mode 100644 index 0000000000..e35893b98f --- /dev/null +++ b/packages/dev-proxy/src/index.ts @@ -0,0 +1,22 @@ +export { + assertDevProxyConfig, + defineDevProxyConfig, + loadDevProxyConfig, + parseDevProxyCliArgs, + resolveDevProxyServerOptions, +} from './config' +export { rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal, toLocalCookieName } from './cookies' +export { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, isAllowedLocalDevOrigin } from './server' +export type { + CookieNameMatcher, + CookieRewriteOptions, + CreateDevProxyAppOptions, + DevProxyCliOptions, + DevProxyConfig, + DevProxyConfigLoadOptions, + DevProxyCorsAllowedOrigins, + DevProxyCorsConfig, + DevProxyRoute, + DevProxyServerConfig, + ResolvedDevProxyServerOptions, +} from './types' diff --git a/web/plugins/dev-proxy/server.spec.ts b/packages/dev-proxy/src/server.spec.ts similarity index 54% rename from web/plugins/dev-proxy/server.spec.ts rename to packages/dev-proxy/src/server.spec.ts index 4b3344be42..32c16a1807 100644 --- a/web/plugins/dev-proxy/server.spec.ts +++ b/packages/dev-proxy/src/server.spec.ts @@ -2,41 +2,13 @@ * @vitest-environment node */ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server' +import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin } from './server' describe('dev proxy server', () => { beforeEach(() => { vi.clearAllMocks() }) - // Scenario: Hono proxy targets should be read directly from env. - it('should resolve Hono proxy targets from env', () => { - // Arrange - const targets = resolveDevProxyTargets({ - HONO_CONSOLE_API_PROXY_TARGET: 'https://console.example.com', - HONO_PUBLIC_API_PROXY_TARGET: 'https://public.example.com', - HONO_ENTERPRISE_API_PROXY_TARGET: 'https://enterprise.example.com', - }) - - // Assert - expect(targets.consoleApiTarget).toBe('https://console.example.com') - expect(targets.publicApiTarget).toBe('https://public.example.com') - expect(targets.enterpriseApiTarget).toBe('https://enterprise.example.com') - }) - - // Scenario: optional proxy targets should use their route-specific defaults. - it('should use console target as the default for optional targets', () => { - // Act - const targets = resolveDevProxyTargets({ - HONO_CONSOLE_API_PROXY_TARGET: 'https://console.example.com', - }) - - // Assert - expect(targets.consoleApiTarget).toBe('https://console.example.com') - expect(targets.publicApiTarget).toBe('https://console.example.com') - expect(targets.enterpriseApiTarget).toBeUndefined() - }) - // Scenario: target paths should not be duplicated when the incoming route already includes them. it('should preserve prefixed targets when building upstream URLs', () => { // Act @@ -46,30 +18,43 @@ describe('dev proxy server', () => { expect(url.href).toBe('https://api.example.com/console/api/apps?page=1') }) - // Scenario: only localhost dev origins should be reflected for credentialed CORS. - it('should only allow local development origins', () => { + // Scenario: only localhost dev origins should be reflected for credentialed CORS by default. + it('should only allow local development origins by default', () => { // Assert expect(isAllowedDevOrigin('http://localhost:3000')).toBe(true) expect(isAllowedDevOrigin('http://127.0.0.1:3000')).toBe(true) expect(isAllowedDevOrigin('https://example.com')).toBe(false) }) - // Scenario: proxy requests should rewrite cookies and surface credentialed CORS headers. - it('should proxy api requests through Hono with local cookie rewriting', async () => { + // Scenario: explicit CORS origins should support non-local development hosts. + it('should allow explicitly configured origins', () => { + // Assert + expect(isAllowedDevOrigin('https://app.example.com', ['https://app.example.com'])).toBe(true) + expect(isAllowedDevOrigin('https://other.example.com', ['https://app.example.com'])).toBe(false) + }) + + // Scenario: proxy requests should rewrite cookies and surface credentialed CORS headers when configured. + it('should proxy api requests with configured local cookie rewriting', async () => { // Arrange const fetchImpl = vi.fn().mockResolvedValue(new Response('ok', { status: 200, headers: [ ['content-encoding', 'br'], ['content-length', '123'], - ['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None'], + ['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.example.com; Secure; SameSite=None'], ['transfer-encoding', 'chunked'], ], })) const app = createDevProxyApp({ - consoleApiTarget: 'https://cloud.dify.ai', - publicApiTarget: 'https://public.dify.ai', - enterpriseApiTarget: 'https://enterprise.dify.ai', + routes: [ + { + paths: '/console/api', + target: 'https://cloud.example.com', + cookieRewrite: { + hostPrefixCookies: ['access_token'], + }, + }, + ], fetchImpl, }) @@ -77,7 +62,7 @@ describe('dev proxy server', () => { const response = await app.request('http://127.0.0.1:5001/console/api/apps?page=1', { headers: { 'Origin': 'http://localhost:3000', - 'Cookie': 'access_token=abc', + 'Cookie': 'access_token=abc; theme=dark', 'Accept-Encoding': 'zstd, br, gzip', }, }) @@ -85,7 +70,7 @@ describe('dev proxy server', () => { // Assert expect(fetchImpl).toHaveBeenCalledTimes(1) expect(fetchImpl).toHaveBeenCalledWith( - new URL('https://cloud.dify.ai/console/api/apps?page=1'), + new URL('https://cloud.example.com/console/api/apps?page=1'), expect.objectContaining({ method: 'GET', headers: expect.any(Headers), @@ -96,8 +81,8 @@ describe('dev proxy server', () => { if (!(requestHeaders instanceof Headers)) throw new Error('Expected proxy request headers to be Headers') - expect(requestHeaders.get('cookie')).toBe('__Host-access_token=abc') - expect(requestHeaders.get('origin')).toBe('https://cloud.dify.ai') + expect(requestHeaders.get('cookie')).toBe('__Host-access_token=abc; theme=dark') + expect(requestHeaders.get('origin')).toBe('https://cloud.example.com') expect(requestHeaders.get('accept-encoding')).toBe('identity') expect(response.headers.get('access-control-allow-origin')).toBe('http://localhost:3000') expect(response.headers.get('access-control-allow-credentials')).toBe('true') @@ -109,14 +94,49 @@ describe('dev proxy server', () => { ]) }) - // Scenario: a local HTTP Dify API expects the non-prefixed local cookie name. + // Scenario: generic proxy routes should not know Dify cookie names by default. + it('should not rewrite cookie names when cookie rewriting is not configured', async () => { + // Arrange + const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) + const app = createDevProxyApp({ + routes: [ + { + paths: '/api', + target: 'https://api.example.com', + }, + ], + fetchImpl, + }) + + // Act + await app.request('http://127.0.0.1:5001/api/messages', { + headers: { + Cookie: 'access_token=abc; refresh_token=def', + }, + }) + + // Assert + const requestHeaders = fetchImpl.mock.calls[0]?.[1]?.headers + if (!(requestHeaders instanceof Headers)) + throw new Error('Expected proxy request headers to be Headers') + + expect(requestHeaders.get('cookie')).toBe('access_token=abc; refresh_token=def') + }) + + // Scenario: local HTTP upstreams expect local cookie names even when cookie rewriting is configured. it('should keep local cookie names for HTTP upstream targets', async () => { // Arrange const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) const app = createDevProxyApp({ - consoleApiTarget: 'http://127.0.0.1:5001', - publicApiTarget: 'http://127.0.0.1:5001', - enterpriseApiTarget: 'http://127.0.0.1:8082', + routes: [ + { + paths: '/console/api', + target: 'http://127.0.0.1:5001', + cookieRewrite: { + hostPrefixCookies: ['access_token', 'refresh_token'], + }, + }, + ], fetchImpl, }) @@ -135,47 +155,59 @@ describe('dev proxy server', () => { expect(requestHeaders.get('cookie')).toBe('access_token=abc; refresh_token=def') }) - // Scenario: Enterprise dashboard routes should use the Enterprise target before generic API routes. - it('should proxy enterprise api routes to the enterprise target', async () => { + // Scenario: custom route paths should support independent upstream targets. + it('should proxy custom route paths to their configured targets', async () => { // Arrange const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) const app = createDevProxyApp({ - consoleApiTarget: 'https://console.example.com', - publicApiTarget: 'https://public.example.com', - enterpriseApiTarget: 'https://enterprise.example.com', + routes: [ + { + paths: '/api', + target: 'https://api.example.com', + }, + { + paths: '/files', + target: 'https://files.example.com/assets', + }, + ], fetchImpl, }) - const requestUrls = [ - 'http://127.0.0.1:5001/console/api/enterprise/sso/saml/login', - 'http://127.0.0.1:5001/api/enterprise/sso/oauth2/login', - 'http://127.0.0.1:5001/admin-api/v1/workspaces', - 'http://127.0.0.1:5001/inner/api/info', - 'http://127.0.0.1:5001/mfa/v1/verify', - 'http://127.0.0.1:5001/scim/v2/Users', - 'http://127.0.0.1:5001/v1/audit/logs', - 'http://127.0.0.1:5001/v1/dashboard/api/license/status', - 'http://127.0.0.1:5001/v1/healthz', - 'http://127.0.0.1:5001/v1/plugin-manager/plugins', - ] - // Act - for (const url of requestUrls) - await app.request(url) + await app.request('http://127.0.0.1:5001/api/messages') + await app.request('http://127.0.0.1:5001/files/logo.png?size=small') // Assert - expect(fetchImpl).toHaveBeenCalledTimes(requestUrls.length) expect(fetchImpl.mock.calls.map(([url]) => url.toString())).toEqual([ - 'https://enterprise.example.com/console/api/enterprise/sso/saml/login', - 'https://enterprise.example.com/api/enterprise/sso/oauth2/login', - 'https://enterprise.example.com/admin-api/v1/workspaces', - 'https://enterprise.example.com/inner/api/info', - 'https://enterprise.example.com/mfa/v1/verify', - 'https://enterprise.example.com/scim/v2/Users', - 'https://enterprise.example.com/v1/audit/logs', - 'https://enterprise.example.com/v1/dashboard/api/license/status', - 'https://enterprise.example.com/v1/healthz', - 'https://enterprise.example.com/v1/plugin-manager/plugins', + 'https://api.example.com/api/messages', + 'https://files.example.com/assets/files/logo.png?size=small', + ]) + }) + + // Scenario: routes are matched in config order so callers can put specific routes first. + it('should prefer earlier route entries', async () => { + // Arrange + const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) + const app = createDevProxyApp({ + routes: [ + { + paths: '/api/enterprise', + target: 'https://enterprise.example.com', + }, + { + paths: '/api', + target: 'https://api.example.com', + }, + ], + fetchImpl, + }) + + // Act + await app.request('http://127.0.0.1:5001/api/enterprise/sso/login') + + // Assert + expect(fetchImpl.mock.calls.map(([url]) => url.toString())).toEqual([ + 'https://enterprise.example.com/api/enterprise/sso/login', ]) }) @@ -183,9 +215,12 @@ describe('dev proxy server', () => { it('should answer CORS preflight requests', async () => { // Arrange const app = createDevProxyApp({ - consoleApiTarget: 'https://cloud.dify.ai', - publicApiTarget: 'https://public.dify.ai', - enterpriseApiTarget: 'https://enterprise.dify.ai', + routes: [ + { + paths: '/api', + target: 'https://api.example.com', + }, + ], fetchImpl: vi.fn(), }) diff --git a/web/plugins/dev-proxy/server.ts b/packages/dev-proxy/src/server.ts similarity index 52% rename from web/plugins/dev-proxy/server.ts rename to packages/dev-proxy/src/server.ts index e4867b6077..79654750da 100644 --- a/web/plugins/dev-proxy/server.ts +++ b/packages/dev-proxy/src/server.ts @@ -1,25 +1,9 @@ import type { Context, Hono } from 'hono' +import type { CookieRewriteOptions, CreateDevProxyAppOptions, DevProxyCorsAllowedOrigins, DevProxyRoute } from './types' import { Hono as HonoApp } from 'hono' -import { DEFAULT_PROXY_TARGET, rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies' +import { rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies' -type DevProxyEnv = Partial> - -type DevProxyTargets = { - consoleApiTarget: string - publicApiTarget: string - enterpriseApiTarget?: string -} - -type DevProxyAppOptions = DevProxyTargets & { - fetchImpl?: typeof globalThis.fetch -} - -const LOCAL_DEV_HOSTS = new Set(['localhost', '127.0.0.1', '[::1]']) +const LOCAL_DEV_HOSTS = new Set(['localhost', '127.0.0.1', '[::1]', '::1']) const ALLOW_METHODS = 'GET,HEAD,POST,PUT,PATCH,DELETE,OPTIONS' const DEFAULT_ALLOW_HEADERS = 'Authorization, Content-Type, X-CSRF-Token' const UPSTREAM_ACCEPT_ENCODING = 'identity' @@ -28,31 +12,14 @@ const RESPONSE_HEADERS_TO_DROP = [ 'content-encoding', 'content-length', 'keep-alive', - 'set-cookie', + 'proxy-authenticate', + 'proxy-authorization', + 'te', + 'trailer', 'transfer-encoding', + 'upgrade', ] as const -const ENTERPRISE_API_ROUTES = [ - '/console/api/enterprise', - '/api/enterprise', - '/admin-api', - '/inner/api', - '/mfa', - '/scim', - '/v1/audit', - '/v1/dashboard', - '/v1/healthz', - '/v1/plugin-manager', -] as const - -const CONSOLE_API_ROUTES = ['/console/api'] as const -const PUBLIC_API_ROUTES = ['/api'] as const - -type ProxyRoutePath - = | typeof ENTERPRISE_API_ROUTES[number] - | typeof CONSOLE_API_ROUTES[number] - | typeof PUBLIC_API_ROUTES[number] - const appendHeaderValue = (headers: Headers, name: string, value: string) => { const currentValue = headers.get(name) if (!currentValue) { @@ -66,7 +33,7 @@ const appendHeaderValue = (headers: Headers, name: string, value: string) => { headers.set(name, `${currentValue}, ${value}`) } -export const isAllowedDevOrigin = (origin?: string | null) => { +export const isAllowedLocalDevOrigin = (origin?: string | null) => { if (!origin) return false @@ -79,8 +46,25 @@ export const isAllowedDevOrigin = (origin?: string | null) => { } } -const applyCorsHeaders = (headers: Headers, origin?: string | null) => { - if (!isAllowedDevOrigin(origin)) +export const isAllowedDevOrigin = ( + origin?: string | null, + allowedOrigins: DevProxyCorsAllowedOrigins = 'local', +) => { + if (!origin) + return false + + if (allowedOrigins === 'local') + return isAllowedLocalDevOrigin(origin) + + return allowedOrigins.includes(origin) +} + +const applyCorsHeaders = ( + headers: Headers, + origin: string | undefined | null, + allowedOrigins: DevProxyCorsAllowedOrigins = 'local', +) => { + if (!isAllowedDevOrigin(origin, allowedOrigins)) return headers.set('Access-Control-Allow-Origin', origin!) @@ -103,7 +87,11 @@ export const buildUpstreamUrl = (target: string, requestPath: string, search = ' return targetUrl } -const createProxyRequestHeaders = (request: Request, targetUrl: URL) => { +const createProxyRequestHeaders = ( + request: Request, + targetUrl: URL, + cookieRewrite: CookieRewriteOptions | false | undefined, +) => { const headers = new Headers(request.headers) headers.delete('host') headers.set('accept-encoding', UPSTREAM_ACCEPT_ENCODING) @@ -111,36 +99,60 @@ const createProxyRequestHeaders = (request: Request, targetUrl: URL) => { if (headers.has('origin')) headers.set('origin', targetUrl.origin) - const rewrittenCookieHeader = rewriteCookieHeaderForUpstream(headers.get('cookie') || undefined, { - useHostPrefix: targetUrl.protocol === 'https:', - }) - if (rewrittenCookieHeader) - headers.set('cookie', rewrittenCookieHeader) + if (cookieRewrite) { + const rewrittenCookieHeader = rewriteCookieHeaderForUpstream(headers.get('cookie') || undefined, { + ...cookieRewrite, + useHostPrefix: targetUrl.protocol === 'https:', + }) + if (rewrittenCookieHeader) + headers.set('cookie', rewrittenCookieHeader) + } return headers } -const createUpstreamResponseHeaders = (response: Response, requestOrigin?: string | null) => { +const getSetCookieHeaders = (headers: Headers) => { + const headersWithGetSetCookie = headers as Headers & { getSetCookie?: () => string[] } + const setCookieHeaders = headersWithGetSetCookie.getSetCookie?.() + if (setCookieHeaders?.length) + return setCookieHeaders + + const setCookie = headers.get('set-cookie') + return setCookie ? [setCookie] : [] +} + +const createUpstreamResponseHeaders = ( + response: Response, + requestOrigin: string | undefined | null, + allowedOrigins: DevProxyCorsAllowedOrigins, + cookieRewrite: CookieRewriteOptions | false | undefined, +) => { const headers = new Headers(response.headers) RESPONSE_HEADERS_TO_DROP.forEach(header => headers.delete(header)) + headers.delete('set-cookie') - const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie()) - rewrittenSetCookies?.forEach((cookie) => { + const setCookieHeaders = getSetCookieHeaders(response.headers) + const responseSetCookieHeaders = cookieRewrite + ? rewriteSetCookieHeadersForLocal(setCookieHeaders) + : setCookieHeaders + + responseSetCookieHeaders.forEach((cookie) => { headers.append('set-cookie', cookie) }) - applyCorsHeaders(headers, requestOrigin) + applyCorsHeaders(headers, requestOrigin, allowedOrigins) return headers } const proxyRequest = async ( context: Context, - target: string, + route: DevProxyRoute, fetchImpl: typeof globalThis.fetch, + allowedOrigins: DevProxyCorsAllowedOrigins, ) => { const requestUrl = new URL(context.req.url) - const targetUrl = buildUpstreamUrl(target, requestUrl.pathname, requestUrl.search) - const requestHeaders = createProxyRequestHeaders(context.req.raw, targetUrl) + const targetUrl = buildUpstreamUrl(route.target, requestUrl.pathname, requestUrl.search) + const requestHeaders = createProxyRequestHeaders(context.req.raw, targetUrl, route.cookieRewrite) const requestInit: RequestInit & { duplex?: 'half' } = { method: context.req.method, headers: requestHeaders, @@ -153,7 +165,12 @@ const proxyRequest = async ( } const upstreamResponse = await fetchImpl(targetUrl, requestInit) - const responseHeaders = createUpstreamResponseHeaders(upstreamResponse, context.req.header('origin')) + const responseHeaders = createUpstreamResponseHeaders( + upstreamResponse, + context.req.header('origin'), + allowedOrigins, + route.cookieRewrite, + ) return new Response(upstreamResponse.body, { status: upstreamResponse.status, @@ -162,48 +179,46 @@ const proxyRequest = async ( }) } +const normalizeRoutePaths = (paths: DevProxyRoute['paths']) => Array.isArray(paths) ? paths : [paths] + const registerProxyRoute = ( app: Hono, - path: ProxyRoutePath, - target: string, + route: DevProxyRoute, + path: string, fetchImpl: typeof globalThis.fetch, + allowedOrigins: DevProxyCorsAllowedOrigins, ) => { - app.all(path, context => proxyRequest(context, target, fetchImpl)) - app.all(`${path}/*`, context => proxyRequest(context, target, fetchImpl)) + if (!path.startsWith('/')) + throw new Error(`Invalid dev proxy route path "${path}". Paths must start with "/".`) + + app.all(path, context => proxyRequest(context, route, fetchImpl, allowedOrigins)) + app.all(`${path}/*`, context => proxyRequest(context, route, fetchImpl, allowedOrigins)) } const registerProxyRoutes = ( app: Hono, - routes: readonly ProxyRoutePath[], - target: string, + routes: readonly DevProxyRoute[], fetchImpl: typeof globalThis.fetch, + allowedOrigins: DevProxyCorsAllowedOrigins, ) => { - routes.forEach(route => registerProxyRoute(app, route, target, fetchImpl)) + routes.forEach((route) => { + normalizeRoutePaths(route.paths).forEach((path) => { + registerProxyRoute(app, route, path, fetchImpl, allowedOrigins) + }) + }) } -export const resolveDevProxyTargets = (env: DevProxyEnv = {}): DevProxyTargets => { - const consoleApiTarget = env.HONO_CONSOLE_API_PROXY_TARGET - || DEFAULT_PROXY_TARGET - const publicApiTarget = env.HONO_PUBLIC_API_PROXY_TARGET - || consoleApiTarget - const enterpriseApiTarget = env.HONO_ENTERPRISE_API_PROXY_TARGET - - return { - consoleApiTarget, - publicApiTarget, - enterpriseApiTarget, - } -} - -export const createDevProxyApp = (options: DevProxyAppOptions) => { +export const createDevProxyApp = (options: CreateDevProxyAppOptions) => { const app = new HonoApp() const fetchImpl = options.fetchImpl || globalThis.fetch + const logger = options.logger || console + const allowedOrigins = options.cors?.allowedOrigins || 'local' app.onError((error, context) => { - console.error('[dev-hono-proxy]', error) + logger.error('[dev-proxy]', error) const headers = new Headers() - applyCorsHeaders(headers, context.req.header('origin')) + applyCorsHeaders(headers, context.req.header('origin'), allowedOrigins) return new Response('Upstream proxy request failed.', { status: 502, @@ -214,7 +229,7 @@ export const createDevProxyApp = (options: DevProxyAppOptions) => { app.use('*', async (context, next) => { if (context.req.method === 'OPTIONS') { const headers = new Headers() - applyCorsHeaders(headers, context.req.header('origin')) + applyCorsHeaders(headers, context.req.header('origin'), allowedOrigins) headers.set('Access-Control-Allow-Methods', ALLOW_METHODS) headers.set( 'Access-Control-Allow-Headers', @@ -230,13 +245,10 @@ export const createDevProxyApp = (options: DevProxyAppOptions) => { } await next() - applyCorsHeaders(context.res.headers, context.req.header('origin')) + applyCorsHeaders(context.res.headers, context.req.header('origin'), allowedOrigins) }) - if (options.enterpriseApiTarget) - registerProxyRoutes(app, ENTERPRISE_API_ROUTES, options.enterpriseApiTarget, fetchImpl) - registerProxyRoutes(app, CONSOLE_API_ROUTES, options.consoleApiTarget, fetchImpl) - registerProxyRoutes(app, PUBLIC_API_ROUTES, options.publicApiTarget, fetchImpl) + registerProxyRoutes(app, options.routes, fetchImpl, allowedOrigins) return app } diff --git a/packages/dev-proxy/src/types.ts b/packages/dev-proxy/src/types.ts new file mode 100644 index 0000000000..2c42b2f7fb --- /dev/null +++ b/packages/dev-proxy/src/types.ts @@ -0,0 +1,50 @@ +export type DevProxyServerConfig = { + host?: string + port?: number +} + +export type DevProxyCorsAllowedOrigins = 'local' | readonly string[] + +export type DevProxyCorsConfig = { + allowedOrigins?: DevProxyCorsAllowedOrigins +} + +export type CookieNameMatcher = string | RegExp + +export type CookieRewriteOptions = { + hostPrefixCookies?: readonly CookieNameMatcher[] +} + +export type DevProxyRoute = { + paths: string | readonly string[] + target: string + cookieRewrite?: CookieRewriteOptions | false +} + +export type DevProxyConfig = { + server?: DevProxyServerConfig + routes: readonly DevProxyRoute[] + cors?: DevProxyCorsConfig +} + +export type DevProxyCliOptions = { + config?: string + envFile?: string + host?: string + port?: string + help?: boolean +} + +export type DevProxyConfigLoadOptions = { + envFile?: string | false +} + +export type ResolvedDevProxyServerOptions = { + host: string + port: number +} + +export type CreateDevProxyAppOptions = Pick & { + fetchImpl?: typeof globalThis.fetch + logger?: Pick +} diff --git a/packages/dev-proxy/tsconfig.json b/packages/dev-proxy/tsconfig.json new file mode 100644 index 0000000000..813a9bd8a3 --- /dev/null +++ b/packages/dev-proxy/tsconfig.json @@ -0,0 +1,17 @@ +{ + "extends": "@dify/tsconfig/node.json", + "compilerOptions": { + "types": [ + "node", + "vitest/globals" + ] + }, + "include": [ + "src/**/*.ts", + "vite.config.ts" + ], + "exclude": [ + "node_modules", + "dist" + ] +} diff --git a/packages/dev-proxy/vite.config.ts b/packages/dev-proxy/vite.config.ts new file mode 100644 index 0000000000..d060ae036e --- /dev/null +++ b/packages/dev-proxy/vite.config.ts @@ -0,0 +1,27 @@ +import { defineConfig } from 'vite-plus' + +export default defineConfig({ + pack: { + clean: true, + deps: { + neverBundle: [ + '@hono/node-server', + 'c12', + 'hono', + ], + }, + entry: [ + 'src/index.ts', + 'src/cli.ts', + ], + format: ['esm'], + outDir: 'dist', + platform: 'node', + sourcemap: true, + target: 'node22', + treeshake: true, + }, + test: { + environment: 'node', + }, +}) diff --git a/packages/dify-ui/AGENTS.md b/packages/dify-ui/AGENTS.md index d8a59b7a0b..bdc2160702 100644 --- a/packages/dify-ui/AGENTS.md +++ b/packages/dify-ui/AGENTS.md @@ -56,4 +56,28 @@ The Figma design system uses `--radius/*` tokens whose scale is **offset by one - When the Figma MCP returns `rounded-[var(--radius/sm, 6px)]`, convert it to the standard Tailwind class from the table above (e.g. `rounded-md`). - For values without a standard Tailwind equivalent (10px, 20px, 28px), use arbitrary values like `rounded-[10px]`. +## Search / Picker Primitive Selection: Autocomplete vs Combobox vs Select + +Pick by whether the user is entering free-form text, choosing a remembered value, or selecting from a closed list. + +Base UI decision rules: + +- [Autocomplete docs]: use `Combobox` instead of `Autocomplete` if the selection should be remembered and the input value cannot be custom. +- [Combobox docs]: do not use `Combobox` for simple search widgets that require unrestricted text entry; use `Autocomplete` instead. + +Apply this split in Dify UI: + +- `Autocomplete` — free-form text input with optional suggestions or completions. The input value may be custom and does not necessarily become a selected option. Use for search boxes, command-style suggestions, tag suggestions, and async text completion. +- `Combobox` — searchable picker whose value is one or more selected items from a collection. The chosen value is remembered by the root, and free-form text is not the final value. Use for model pickers, user pickers, dataset/document pickers, and multi-select chips. +- `Select` — closed-list picker without text entry. Use when the option set is small or already scannable and filtering is unnecessary. + +Composition rules: + +- Keep Base UI primitive semantics visible in the public API. Export compound parts such as `ComboboxInputGroup`, `ComboboxInput`, `ComboboxContent`, `ComboboxList`, `ComboboxItem`, and `ComboboxItemIndicator` instead of wrapping them into one business component. +- For `Combobox` multiple selection, follow the official chips pattern: `ComboboxInputGroup` contains `ComboboxChips`, `ComboboxValue` renders `ComboboxChip` items, and `ComboboxInput` remains inside the chips row. Chips should wrap and let the input group grow vertically instead of forcing horizontal overflow. +- Content primitives must own their Base UI `Portal` and use `z-1002` on `Positioner`, matching the overlay contract in `README.md`. +- Use `w-(--anchor-width)` with viewport-aware max-width for `Autocomplete` and `Combobox` popups. Do not add `min-w-(--anchor-width)` when it would defeat available-width clamping. + +[Autocomplete docs]: https://base-ui.com/react/components/autocomplete.md#usage-guidelines +[Combobox docs]: https://base-ui.com/react/components/combobox.md#usage-guidelines [docs]: https://base-ui.com/react/components/tooltip#infotips diff --git a/packages/dify-ui/README.md b/packages/dify-ui/README.md index cd24a0c078..41e99d0952 100644 --- a/packages/dify-ui/README.md +++ b/packages/dify-ui/README.md @@ -36,12 +36,12 @@ Importing from `@langgenius/dify-ui` (no subpath) is intentionally not supported ## Primitives -| Category | Subpath | Notes | -| -------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------- | -| Overlay | `./alert-dialog`, `./context-menu`, `./dialog`, `./dropdown-menu`, `./popover`, `./select`, `./toast`, `./tooltip` | Portalled. See [Overlay & portal contract] below. | -| Form | `./number-field`, `./slider`, `./switch` | Controlled / uncontrolled per Base UI defaults. | -| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. | -| Media | `./avatar`, `./button` | Button exposes `cva` variants. | +| Category | Subpath | Notes | +| -------- | -------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------- | +| Overlay | `./alert-dialog`, `./autocomplete`, `./combobox`, `./context-menu`, `./dialog`, `./dropdown-menu`, `./popover`, `./select`, `./toast`, `./tooltip` | Portalled. See [Overlay & portal contract] below. | +| Form | `./autocomplete`, `./combobox`, `./number-field`, `./slider`, `./switch` | Controlled / uncontrolled per Base UI defaults. | +| Layout | `./scroll-area` | Custom-styled scrollbar over the host viewport. | +| Media | `./avatar`, `./button` | Button exposes `cva` variants. | Utilities: @@ -65,7 +65,7 @@ If a consumer uses Dify UI source files through the workspace, add an explicit s ## Overlay & portal contract -All overlay primitives (`dialog`, `alert-dialog`, `popover`, `dropdown-menu`, `context-menu`, `select`, `tooltip`, `toast`) render their content inside a [Base UI Portal] attached to `document.body`. This is the Base UI default — see the upstream [Portals][Base UI Portal] docs for the underlying behavior. Consumers **do not** need to wrap anything in a portal manually. +All overlay primitives (`dialog`, `alert-dialog`, `autocomplete`, `combobox`, `popover`, `dropdown-menu`, `context-menu`, `select`, `tooltip`, `toast`) render their content inside a [Base UI Portal] attached to `document.body`. This is the Base UI default — see the upstream [Portals][Base UI Portal] docs for the underlying behavior. Consumers **do not** need to wrap anything in a portal manually. ### Root isolation requirement @@ -83,10 +83,10 @@ Equivalent: any root element with `isolation: isolate` in CSS. Without it, overl Every overlay primitive uses a single, shared z-index. Do **not** override it at call sites. -| Layer | z-index | Where | -| ----------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- | -| Overlays (Dialog, AlertDialog, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop | -| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. | +| Layer | z-index | Where | +| ----------------------------------------------------------------------------------------------------------- | -------- | -------------------------------------------------------------------------- | +| Overlays (Dialog, AlertDialog, Autocomplete, Combobox, Popover, DropdownMenu, ContextMenu, Select, Tooltip) | `z-1002` | Positioner / Backdrop | +| Toast viewport | `z-1003` | One layer above overlays so notifications are never hidden under a dialog. | Rationale: during Dify's migration from legacy `portal-to-follow-elem` / `base/modal` / `base/dialog` overlays to this package, new and old overlays coexist in the DOM. `z-1002` sits above any common legacy layer, eliminating per-call-site z-index hacks. Among themselves, new primitives share the same z-index and **rely on DOM order** for stacking — the portal mounted later wins. diff --git a/packages/dify-ui/package.json b/packages/dify-ui/package.json index 73c6c0bd22..20e94c7dee 100644 --- a/packages/dify-ui/package.json +++ b/packages/dify-ui/package.json @@ -13,6 +13,10 @@ "types": "./src/alert-dialog/index.tsx", "import": "./src/alert-dialog/index.tsx" }, + "./autocomplete": { + "types": "./src/autocomplete/index.tsx", + "import": "./src/autocomplete/index.tsx" + }, "./avatar": { "types": "./src/avatar/index.tsx", "import": "./src/avatar/index.tsx" @@ -21,6 +25,10 @@ "types": "./src/button/index.tsx", "import": "./src/button/index.tsx" }, + "./combobox": { + "types": "./src/combobox/index.tsx", + "import": "./src/combobox/index.tsx" + }, "./context-menu": { "types": "./src/context-menu/index.tsx", "import": "./src/context-menu/index.tsx" @@ -103,6 +111,7 @@ "@storybook/addon-themes": "catalog:", "@storybook/react-vite": "catalog:", "@tailwindcss/vite": "catalog:", + "@tanstack/react-virtual": "catalog:", "@types/react": "catalog:", "@types/react-dom": "catalog:", "@typescript/native-preview": "catalog:", diff --git a/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx new file mode 100644 index 0000000000..a7031c5b12 --- /dev/null +++ b/packages/dify-ui/src/autocomplete/__tests__/index.spec.tsx @@ -0,0 +1,252 @@ +import type { ReactNode } from 'react' +import { render } from 'vitest-browser-react' +import { + Autocomplete, + AutocompleteClear, + AutocompleteContent, + AutocompleteEmpty, + AutocompleteGroup, + AutocompleteInput, + AutocompleteInputGroup, + AutocompleteItem, + AutocompleteItemIndicator, + AutocompleteItemText, + AutocompleteLabel, + AutocompleteList, + AutocompleteSeparator, + AutocompleteStatus, + AutocompleteTrigger, +} from '../index' + +const renderWithSafeViewport = (ui: ReactNode) => render( +
+ {ui} +
, +) + +const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement + +const renderAutocomplete = ({ + children, + open = false, + defaultValue = 'workflow', +}: { + children?: ReactNode + open?: boolean + defaultValue?: string +} = {}) => renderWithSafeViewport( + + {children ?? ( + <> + + + + + + + 2 suggestions + + + Workflow + + + + Dataset + + + No suggestions + + + )} + , +) + +describe('Autocomplete wrappers', () => { + describe('Input group and input', () => { + it('should apply medium input group and input classes by default', async () => { + const screen = await renderAutocomplete() + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-lg') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('px-3') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('system-sm-regular') + }) + + it('should apply large input group and input classes when large size is provided', async () => { + const screen = await renderAutocomplete({ + children: ( + + + + ), + }) + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-[10px]') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('px-4') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('system-md-regular') + }) + + it('should set input defaults and forward passthrough props', async () => { + const screen = await renderAutocomplete({ + children: ( + + + + ), + }) + + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('autocomplete', 'off') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('type', 'text') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveAttribute('placeholder', 'Find a resource') + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toBeRequired() + await expect.element(screen.getByRole('combobox', { name: 'Search suggestions' })).toHaveClass('custom-input') + }) + }) + + describe('Controls', () => { + it('should provide fallback aria labels and decorative icons when labels are omitted', async () => { + const screen = await renderAutocomplete() + + await expect.element(screen.getByRole('button', { name: 'Clear autocomplete' })).toHaveAttribute('type', 'button') + await expect.element(screen.getByRole('button', { name: 'Open autocomplete suggestions' })).toHaveAttribute('type', 'button') + expect(screen.getByRole('button', { name: 'Clear autocomplete' }).element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true') + expect(screen.getByRole('button', { name: 'Open autocomplete suggestions' }).element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should preserve explicit labels and custom children', async () => { + const screen = await renderAutocomplete({ + children: ( + + + + reset + + + open + + + ), + }) + + expect(screen.getByRole('button', { name: 'Reset search' }).element()).toContainElement(screen.getByTestId('custom-clear').element()) + expect(screen.getByRole('button', { name: 'Show suggestions' }).element()).toContainElement(screen.getByTestId('custom-trigger').element()) + expect(screen.getByRole('button', { name: 'Reset search' }).element().querySelector('.i-ri-close-line')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'Show suggestions' }).element().querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument() + }) + + it('should rely on aria-labelledby when provided instead of injecting fallback labels', async () => { + const screen = await renderAutocomplete({ + children: ( + <> + Clear from label + Trigger from label + + + + + + + ), + }) + + await expect.element(screen.getByRole('button', { name: 'Clear from label' })).not.toHaveAttribute('aria-label') + await expect.element(screen.getByRole('button', { name: 'Trigger from label' })).not.toHaveAttribute('aria-label') + }) + }) + + describe('Content and options', () => { + it('should use default overlay placement and Dify popup classes', async () => { + const screen = await renderAutocomplete({ open: true }) + + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'bottom') + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-align', 'start') + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('rounded-xl') + await expect.element(screen.getByRole('dialog', { name: 'autocomplete popup' })).toHaveClass('w-(--anchor-width)') + await expect.element(screen.getByRole('listbox', { name: 'autocomplete list' })).toHaveClass('scroll-py-1') + }) + + it('should apply custom placement side and passthrough popup props', async () => { + const onPopupClick = vi.fn() + const screen = await renderWithSafeViewport( + + + + + + + + Workflow + + + + , + ) + + asHTMLElement(screen.getByRole('dialog', { name: 'autocomplete popup' }).element()).click() + + await expect.element(screen.getByRole('group', { name: 'autocomplete positioner' })).toHaveAttribute('data-side', 'top') + expect(onPopupClick).toHaveBeenCalledTimes(1) + }) + + it('should render item text indicator status and empty wrappers with design classes', async () => { + const screen = await renderAutocomplete({ open: true }) + + await expect.element(screen.getByText('Workflow')).toHaveClass('system-sm-medium') + await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary') + await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular') + expect(screen.getByText('Workflow').element().parentElement?.querySelector('.i-ri-arrow-right-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should forward custom classes to label separator item text and indicator', async () => { + const screen = await renderWithSafeViewport( + + + + + + + + Resources + + + Workflow + + + + + + , + ) + + await expect.element(screen.getByText('Resources')).toHaveClass('custom-label') + await expect.element(screen.getByTestId('separator')).toHaveClass('custom-separator') + await expect.element(screen.getByRole('option', { name: 'Workflow' })).toHaveClass('custom-item') + await expect.element(screen.getByText('Workflow')).toHaveClass('custom-text') + await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator') + }) + }) +}) diff --git a/packages/dify-ui/src/autocomplete/index.stories.tsx b/packages/dify-ui/src/autocomplete/index.stories.tsx new file mode 100644 index 0000000000..71c7c6607d --- /dev/null +++ b/packages/dify-ui/src/autocomplete/index.stories.tsx @@ -0,0 +1,721 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import type { Virtualizer } from '@tanstack/react-virtual' +import type { RefObject } from 'react' +import { useVirtualizer } from '@tanstack/react-virtual' +import { useEffect, useMemo, useRef, useState } from 'react' +import { + Autocomplete, + AutocompleteClear, + AutocompleteCollection, + AutocompleteContent, + AutocompleteEmpty, + AutocompleteGroup, + AutocompleteInput, + AutocompleteInputGroup, + AutocompleteItem, + AutocompleteItemText, + AutocompleteLabel, + AutocompleteList, + AutocompleteSeparator, + AutocompleteStatus, + AutocompleteTrigger, + useAutocompleteFilter, + useAutocompleteFilteredItems, +} from '.' +import { cn } from '../cn' + +type Suggestion = { + value: string + label: string + description?: string + icon?: string + meta?: string +} + +type SuggestionGroup = { + label: string + items: Suggestion[] +} + +const inputWidth = 'w-80' + +type StoryVirtualizer = Virtualizer + +const scrollHighlightedVirtualItem = ( + item: unknown, + { + reason, + index, + }: { + reason: 'keyboard' | 'pointer' | 'none' + index: number + }, + virtualizer: StoryVirtualizer | null, +) => { + if (!item || !virtualizer) + return + + const isStart = index === 0 + const isEnd = index === virtualizer.options.count - 1 + const shouldScroll = reason === 'none' || (reason === 'keyboard' && (isStart || isEnd)) + + if (shouldScroll) { + queueMicrotask(() => { + virtualizer.scrollToIndex(index, { align: isEnd ? 'start' : 'end' }) + }) + } +} + +const tagSuggestions: Suggestion[] = [ + { value: 'feature', label: 'feature', description: 'Product work and launch notes' }, + { value: 'fix', label: 'fix', description: 'Bug fixes and regressions' }, + { value: 'docs', label: 'docs', description: 'Documentation updates' }, + { value: 'internal', label: 'internal', description: 'Workspace-only notes' }, + { value: 'mobile', label: 'mobile', description: 'Mobile app issues' }, + { value: 'component: autocomplete', label: 'component: autocomplete', description: 'Base UI primitive wrapper' }, + { value: 'component: combobox', label: 'component: combobox', description: 'Filterable predefined selection' }, + { value: 'component: select', label: 'component: select', description: 'Compact predefined selection' }, +] + +const promptCompletions: Suggestion[] = [ + { value: 'summarize this conversation', label: 'summarize this conversation' }, + { value: 'summarize this dataset with citations', label: 'summarize this dataset with citations' }, + { value: 'summarize this workflow run for an operator', label: 'summarize this workflow run for an operator' }, + { value: 'summarize this support ticket in 3 bullets', label: 'summarize this support ticket in 3 bullets' }, +] + +const workflowSuggestions: Suggestion[] = [ + { value: 'http-request', label: 'HTTP Request', description: 'Call an external API', icon: 'i-ri-global-line', meta: 'Tool' }, + { value: 'knowledge-retrieval', label: 'Knowledge Retrieval', description: 'Search configured datasets', icon: 'i-ri-database-2-line', meta: 'Tool' }, + { value: 'code-execution', label: 'Code Execution', description: 'Run sandboxed snippets', icon: 'i-ri-code-s-slash-line', meta: 'Tool' }, + { value: 'template-transform', label: 'Template Transform', description: 'Compose variables into output', icon: 'i-ri-braces-line', meta: 'Tool' }, + { value: 'question-classifier', label: 'Question Classifier', description: 'Route by intent', icon: 'i-ri-git-branch-line', meta: 'Tool' }, + { value: 'parameter-extractor', label: 'Parameter Extractor', description: 'Extract typed values', icon: 'i-ri-list-check-3', meta: 'Tool' }, + { value: 'answer-node', label: 'Answer Node', description: 'Return a final assistant answer', icon: 'i-ri-message-3-line', meta: 'Node' }, + { value: 'iteration-node', label: 'Iteration Node', description: 'Run a loop over array items', icon: 'i-ri-repeat-line', meta: 'Node' }, + { value: 'variable-assigner', label: 'Variable Assigner', description: 'Persist intermediate state', icon: 'i-ri-pencil-ruler-2-line', meta: 'Node' }, +] + +const groupedSuggestions: SuggestionGroup[] = [ + { + label: 'Tags', + items: tagSuggestions.slice(0, 5), + }, + { + label: 'Workflow Suggestions', + items: workflowSuggestions.slice(0, 5), + }, + { + label: 'Prompt Starters', + items: promptCompletions.slice(0, 3), + }, +] + +const commandGroups: SuggestionGroup[] = [ + { + label: 'App', + items: [ + { value: '/run', label: 'Run workflow', description: 'Execute the current draft', icon: 'i-ri-play-circle-line' }, + { value: '/publish', label: 'Publish app', description: 'Ship the current configuration', icon: 'i-ri-upload-cloud-2-line' }, + { value: '/trace', label: 'Open trace', description: 'Inspect the latest workflow run', icon: 'i-ri-route-line' }, + ], + }, + { + label: 'Workspace', + items: [ + { value: '/dataset', label: 'Search datasets', description: 'Find knowledge attached to this app', icon: 'i-ri-database-line' }, + { value: '/members', label: 'Invite members', description: 'Open workspace access settings', icon: 'i-ri-user-add-line' }, + { value: '/usage', label: 'View usage', description: 'Open model and workflow usage', icon: 'i-ri-bar-chart-line' }, + ], + }, +] + +const remoteSuggestions: Suggestion[] = [ + { value: 'agent-builder', label: 'Agent Builder', description: 'Workspace app' }, + { value: 'agent-observability', label: 'Agent Observability', description: 'Dataset' }, + { value: 'agent-routing-dataset', label: 'Agent Routing Dataset', description: 'Knowledge source' }, +] + +const virtualizedSuggestions: Suggestion[] = Array.from({ length: 1000 }, (_, index) => { + const family = ['workflow', 'dataset', 'prompt', 'tool'][index % 4]! + const number = new Intl.NumberFormat('en-US', { + minimumIntegerDigits: 4, + }).format(index + 1) + + return { + value: `${family}-${index + 1}`, + label: `${family} suggestion ${number}`, + description: `Free-form autocomplete result from ${family} search`, + icon: family === 'dataset' + ? 'i-ri-database-2-line' + : family === 'prompt' + ? 'i-ri-text-snippet' + : family === 'tool' + ? 'i-ri-tools-line' + : 'i-ri-flow-chart', + meta: family, + } +}) + +const getSuggestionLabel = (item: Suggestion) => item.label + +const SuggestionItem = ({ + item, + index, + dense, +}: { + item: Suggestion + index?: number + dense?: boolean +}) => ( + + {item.icon && +) + +const TagSuggestionItem = ({ + item, + index, +}: { + item: Suggestion + index?: number +}) => ( + + {item.label} + {item.description && {item.description}} + +) + +const BasicTagAutocomplete = ({ + size = 'medium', +}: { + size?: 'small' | 'medium' | 'large' +}) => ( + + + + + + {(item: Suggestion, index: number) => ( + + )} + + No tag suggestion. Keep the typed value. + + +) + +const GroupedSuggestionList = () => { + const groups = useAutocompleteFilteredItems() + + return ( + + {groups.map((group, groupIndex) => ( + + {groupIndex > 0 && } + {group.label} + + {(item: Suggestion) => ( + + )} + + + ))} + + ) +} + +const CommandPaletteList = () => { + const groups = useAutocompleteFilteredItems() + + return ( + + {groups.map((group, groupIndex) => ( + + {groupIndex > 0 && } + {group.label} + + {(item: Suggestion) => ( + + + {item.icon && + + Enter + + + )} + + + ))} + + ) +} + +const LimitedStatus = ({ + total, +}: { + total: number +}) => { + const items = useAutocompleteFilteredItems() + const hidden = Math.max(0, total - items.length) + + return hidden > 0 + ? `${hidden} more suggestions hidden. Refine the query to narrow results.` + : `${items.length} suggestions available.` +} + +const AsyncSearchDemo = () => { + const [value, setValue] = useState('agent') + const [loading, setLoading] = useState(false) + const [items, setItems] = useState(remoteSuggestions) + + useEffect(() => { + setLoading(true) + const timeout = window.setTimeout(() => { + setItems( + value.trim() + ? remoteSuggestions.filter(item => item.label.toLowerCase().includes(value.trim().toLowerCase())) + : remoteSuggestions, + ) + setLoading(false) + }, 500) + + return () => window.clearTimeout(timeout) + }, [value]) + + return ( +
+ + + + + + {loading ? 'Loading suggestions…' : `${items.length} remote suggestions`} + + + {(item: Suggestion, index: number) => ( + + )} + + No remote suggestion. Keep the typed query. + + +
+ ) +} + +const VirtualizedSuggestionList = ({ + virtualizerRef, +}: { + virtualizerRef: RefObject +}) => { + const scrollRef = useRef(null) + const filteredItems = useAutocompleteFilteredItems() + const virtualizer = useVirtualizer({ + count: filteredItems.length, + getScrollElement: () => scrollRef.current, + estimateSize: () => 44, + overscan: 6, + }) + + useEffect(() => { + virtualizerRef.current = virtualizer + + return () => { + virtualizerRef.current = null + } + }, [virtualizer, virtualizerRef]) + + return ( +
+ + {virtualizer.getVirtualItems().map((virtualItem) => { + const item = filteredItems[virtualItem.index] + + if (!item) + return null + + return ( +
+ +
+ ) + })} +
+
+ ) +} + +const VirtualizedStatus = () => { + const filteredItems = useAutocompleteFilteredItems() + + return ( + + {filteredItems.length} + {' '} + matching suggestions. Selecting one only replaces the input text. + + ) +} + +const FuzzyHighlight = ({ + text, + query, +}: { + text: string + query: string +}) => { + const parts = useMemo(() => { + const trimmed = query.trim() + + if (!trimmed) + return [text] + + const escaped = trimmed.slice(0, 80).replace(/[.*+?^${}()|[\]\\]/g, '\\$&') + return text.split(new RegExp(`(${escaped})`, 'i')) + }, [query, text]) + + return ( + <> + {parts.map((part, index) => ( + part.toLowerCase() === query.trim().toLowerCase() + ? {part} + : part + ))} + + ) +} + +const FuzzyMatchingDemo = () => { + const [value, setValue] = useState('retr') + const { contains } = useAutocompleteFilter({ sensitivity: 'base' }) + + return ( +
+ + + + + + {(item: Suggestion, index: number) => ( + + {item.icon && + )} + + No workflow suggestion. Keep typing freely. + + +
+ ) +} + +const meta = { + title: 'Base/UI/Autocomplete', + component: Autocomplete, + parameters: { + layout: 'centered', + docs: { + description: { + component: 'Compound autocomplete built on Base UI Autocomplete. Use it for free-form inputs where suggestions can replace or complete the typed text, but selection is not persistent state.', + }, + }, + }, + tags: ['autodocs'], +} satisfies Meta + +export default meta +type Story = StoryObj + +export const SearchTags: Story = { + render: () => ( +
+ +
+ ), +} + +export const Sizes: Story = { + render: () => ( +
+ {(['small', 'medium', 'large'] as const).map(size => ( +
+ +
+ ))} +
+ ), +} + +export const InlineAutocomplete: Story = { + render: () => ( +
+ + + + + + {(item: Suggestion, index: number) => ( + + )} + + No inline completion. Continue typing freely. + + +
+ ), +} + +export const GroupedSuggestions: Story = { + render: () => ( +
+ + + + + + No suggestion. Use the text as entered. + + +
+ ), +} + +export const FuzzyMatching: Story = { + render: () => , +} + +export const LimitResults: Story = { + render: () => ( +
+ + + + + + + + + {(item: Suggestion, index: number) => ( + + )} + + No suggestion. Submit the typed text instead. + + +
+ ), +} + +export const CommandPalette: Story = { + render: () => ( +
+ + + + + +
+ ), +} + +const VirtualizedLongSuggestionsDemo = () => { + const virtualizerRef = useRef(null) + + return ( +
+ { + scrollHighlightedVirtualItem(item, details, virtualizerRef.current) + }} + > + + + + + + No suggestion. Free-form text is still valid. + + +
+ ) +} + +export const VirtualizedLongSuggestions: Story = { + render: () => , +} + +export const AsyncSearch: Story = { + render: () => , +} + +export const Empty: Story = { + render: () => ( +
+ + + + + + {(item: Suggestion, index: number) => ( + + )} + + No tag suggestion. The custom text remains valid. + + +
+ ), +} + +export const DisabledAndReadOnly: Story = { + render: () => ( +
+ + + + + + + + + {(item: Suggestion, index: number) => ( + + )} + + + + + + + + + + + + {(item: Suggestion, index: number) => ( + + )} + + + +
+ ), +} diff --git a/packages/dify-ui/src/autocomplete/index.tsx b/packages/dify-ui/src/autocomplete/index.tsx new file mode 100644 index 0000000000..16c4b19673 --- /dev/null +++ b/packages/dify-ui/src/autocomplete/index.tsx @@ -0,0 +1,381 @@ +'use client' + +import type { VariantProps } from 'class-variance-authority' +import type { HTMLAttributes, ReactNode } from 'react' +import type { Placement } from '../placement' +import { Autocomplete as BaseAutocomplete } from '@base-ui/react/autocomplete' +import { cva } from 'class-variance-authority' +import { cn } from '../cn' +import { + overlayIndicatorClassName, + overlayLabelClassName, + overlayPopupAnimationClassName, + overlaySeparatorClassName, +} from '../overlay-shared' +import { parsePlacement } from '../placement' + +export type { Placement } + +export const Autocomplete = BaseAutocomplete.Root +export const AutocompleteValue = BaseAutocomplete.Value +export const AutocompleteGroup = BaseAutocomplete.Group +export const AutocompleteCollection = BaseAutocomplete.Collection +export const AutocompleteRow = BaseAutocomplete.Row +export const useAutocompleteFilter = BaseAutocomplete.useFilter +export const useAutocompleteFilteredItems = BaseAutocomplete.useFilteredItems + +export type AutocompleteRootProps = BaseAutocomplete.Root.Props +export type AutocompleteRootChangeEventDetails = BaseAutocomplete.Root.ChangeEventDetails +export type AutocompleteRootHighlightEventDetails = BaseAutocomplete.Root.HighlightEventDetails + +const autocompletePopupClassName = [ + 'w-(--anchor-width) max-w-[min(28rem,var(--available-width))] overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg outline-hidden', + 'data-side-top:origin-bottom data-side-bottom:origin-top data-side-left:origin-right data-side-right:origin-left', +] + +const autocompleteListClassName = [ + 'max-h-[min(20rem,var(--available-height))] overflow-y-auto overflow-x-hidden overscroll-contain p-1 outline-hidden scroll-py-1', + 'data-empty:max-h-none data-empty:p-0', +] + +const autocompleteItemClassName = [ + 'mx-1 flex min-h-8 cursor-pointer select-none items-center gap-2 rounded-lg px-2 py-1.5 text-text-secondary outline-hidden transition-colors', + 'hover:bg-state-base-hover-alt hover:text-text-primary', + 'data-highlighted:bg-state-base-hover data-highlighted:text-text-primary', + 'data-disabled:cursor-not-allowed data-disabled:opacity-30 data-disabled:hover:bg-transparent data-disabled:hover:text-text-secondary', + 'motion-reduce:transition-none', +] + +const autocompleteInputGroupVariants = cva( + [ + 'group/autocomplete flex w-full min-w-0 items-center border border-transparent bg-components-input-bg-normal text-components-input-text-filled shadow-none outline-hidden transition-[background-color,border-color,box-shadow]', + 'hover:border-components-input-border-hover hover:bg-components-input-bg-hover', + 'focus-within:border-components-input-border-active focus-within:bg-components-input-bg-active focus-within:shadow-xs', + 'data-focused:border-components-input-border-active data-focused:bg-components-input-bg-active data-focused:shadow-xs', + 'data-disabled:cursor-not-allowed data-disabled:border-transparent data-disabled:bg-components-input-bg-disabled data-disabled:text-components-input-text-filled-disabled', + 'data-disabled:hover:border-transparent data-disabled:hover:bg-components-input-bg-disabled', + 'data-readonly:shadow-none data-readonly:hover:border-transparent data-readonly:hover:bg-components-input-bg-normal', + 'motion-reduce:transition-none', + ], + { + variants: { + size: { + small: 'h-6 rounded-md', + medium: 'h-8 rounded-lg', + large: 'h-9 rounded-[10px]', + }, + }, + defaultVariants: { + size: 'medium', + }, + }, +) + +export type AutocompleteSize = NonNullable['size']> + +export type AutocompleteInputGroupProps + = BaseAutocomplete.InputGroup.Props + & VariantProps + +export function AutocompleteInputGroup({ + className, + size = 'medium', + ...props +}: AutocompleteInputGroupProps) { + return ( + + ) +} + +const autocompleteInputVariants = cva( + [ + 'w-0 min-w-0 flex-1 appearance-none border-0 bg-transparent text-components-input-text-filled caret-primary-600 outline-hidden', + 'placeholder:text-components-input-text-placeholder', + 'disabled:cursor-not-allowed disabled:text-components-input-text-filled-disabled disabled:placeholder:text-components-input-text-disabled', + 'data-readonly:cursor-default', + ], + { + variants: { + size: { + small: 'px-2 py-1 system-xs-regular', + medium: 'px-3 py-[7px] system-sm-regular', + large: 'px-4 py-2 system-md-regular', + }, + }, + defaultVariants: { + size: 'medium', + }, + }, +) + +export type AutocompleteInputProps + = Omit + & VariantProps + +export function AutocompleteInput({ + className, + size = 'medium', + type = 'text', + autoComplete = 'off', + ...props +}: AutocompleteInputProps) { + return ( + + ) +} + +const autocompleteControlVariants = cva( + [ + 'flex shrink-0 touch-manipulation items-center justify-center rounded-md text-text-tertiary outline-hidden transition-colors', + 'hover:bg-components-input-bg-hover hover:text-text-secondary focus-visible:bg-components-input-bg-hover focus-visible:text-text-secondary', + 'focus-visible:ring-1 focus-visible:ring-components-input-border-active focus-visible:ring-inset', + 'disabled:cursor-not-allowed disabled:hover:bg-transparent disabled:hover:text-text-tertiary disabled:focus-visible:bg-transparent disabled:focus-visible:ring-0', + 'group-data-disabled/autocomplete:cursor-not-allowed group-data-disabled/autocomplete:hover:bg-transparent group-data-disabled/autocomplete:focus-visible:bg-transparent group-data-disabled/autocomplete:focus-visible:ring-0', + 'group-data-readonly/autocomplete:hidden', + 'motion-reduce:transition-none', + ], + { + variants: { + size: { + small: 'mr-1 size-4', + medium: 'mr-1.5 size-5', + large: 'mr-2 size-5', + }, + }, + defaultVariants: { + size: 'medium', + }, + }, +) + +export type AutocompleteControlProps + = Omit + & VariantProps + & { className?: string } + +export function AutocompleteTrigger({ + className, + children, + size = 'medium', + type = 'button', + ...props +}: AutocompleteControlProps) { + return ( + + {children ?? + ) +} + +export type AutocompleteClearProps + = Omit + & VariantProps + & { className?: string } + +export function AutocompleteClear({ + className, + children, + size = 'medium', + type = 'button', + ...props +}: AutocompleteClearProps) { + return ( + + {children ?? + ) +} + +export function AutocompleteIcon({ + className, + children, + ...props +}: BaseAutocomplete.Icon.Props) { + return ( + + {children ?? + ) +} + +type AutocompleteContentProps = { + children: ReactNode + placement?: Placement + sideOffset?: number + alignOffset?: number + className?: string + popupClassName?: string + portalProps?: Omit + positionerProps?: Omit< + BaseAutocomplete.Positioner.Props, + 'children' | 'className' | 'side' | 'align' | 'sideOffset' | 'alignOffset' + > + popupProps?: Omit< + BaseAutocomplete.Popup.Props, + 'children' | 'className' + > +} + +export function AutocompleteContent({ + children, + placement = 'bottom-start', + sideOffset = 4, + alignOffset = 0, + className, + popupClassName, + portalProps, + positionerProps, + popupProps, +}: AutocompleteContentProps) { + const { side, align } = parsePlacement(placement) + + return ( + + + + {children} + + + + ) +} + +export function AutocompleteList({ + className, + ...props +}: BaseAutocomplete.List.Props) { + return ( + + ) +} + +export function AutocompleteItem({ + className, + ...props +}: BaseAutocomplete.Item.Props) { + return ( + + ) +} + +export type AutocompleteItemTextProps = HTMLAttributes + +export function AutocompleteItemText({ + className, + ...props +}: AutocompleteItemTextProps) { + return ( + + ) +} + +export function AutocompleteLabel({ + className, + ...props +}: BaseAutocomplete.GroupLabel.Props) { + return ( + + ) +} + +export function AutocompleteSeparator({ + className, + ...props +}: BaseAutocomplete.Separator.Props) { + return ( + + ) +} + +export function AutocompleteEmpty({ + className, + ...props +}: BaseAutocomplete.Empty.Props) { + return ( + + ) +} + +export function AutocompleteStatus({ + className, + ...props +}: BaseAutocomplete.Status.Props) { + return ( + + ) +} + +export function AutocompleteItemIndicator({ + className, + children, + ...props +}: HTMLAttributes) { + return ( + + {children ?? + ) +} diff --git a/packages/dify-ui/src/combobox/__tests__/index.spec.tsx b/packages/dify-ui/src/combobox/__tests__/index.spec.tsx new file mode 100644 index 0000000000..705ebe9601 --- /dev/null +++ b/packages/dify-ui/src/combobox/__tests__/index.spec.tsx @@ -0,0 +1,363 @@ +import type { ReactNode } from 'react' +import { render } from 'vitest-browser-react' +import { + Combobox, + ComboboxChip, + ComboboxChipRemove, + ComboboxChips, + ComboboxClear, + ComboboxContent, + ComboboxEmpty, + ComboboxGroup, + ComboboxGroupLabel, + ComboboxInput, + ComboboxInputGroup, + ComboboxInputTrigger, + ComboboxItem, + ComboboxItemIndicator, + ComboboxItemText, + ComboboxLabel, + ComboboxList, + ComboboxSeparator, + ComboboxStatus, + ComboboxTrigger, + ComboboxValue, +} from '../index' + +const renderWithSafeViewport = (ui: ReactNode) => render( +
+ {ui} +
, +) + +const asHTMLElement = (element: HTMLElement | SVGElement) => element as HTMLElement + +const renderSelectLikeCombobox = ({ + children, + open = false, +}: { + children?: ReactNode + open?: boolean +} = {}) => renderWithSafeViewport( + + {children ?? ( + <> + Resource type + + + + + 2 options + + + Workflow + + + + Dataset + + + No options + + + )} + , +) + +const renderInputCombobox = ({ + children, + open = false, +}: { + children?: ReactNode + open?: boolean +} = {}) => renderWithSafeViewport( + + {children ?? ( + <> + + + + + + + + + Workflow + + + + + + )} + , +) + +describe('Combobox wrappers', () => { + describe('Select-like trigger', () => { + it('should render label and apply medium trigger classes by default', async () => { + const screen = await renderSelectLikeCombobox() + + await expect.element(screen.getByText('Resource type')).toHaveClass('system-sm-medium') + await expect.element(screen.getByRole('combobox', { name: 'Resource type' })).toHaveClass('rounded-lg') + await expect.element(screen.getByRole('combobox', { name: 'Resource type' })).toHaveClass('system-sm-regular') + }) + + it('should apply small and large trigger size variants', async () => { + const smallScreen = await renderSelectLikeCombobox({ + children: ( + + + + ), + }) + + await expect.element(smallScreen.getByRole('combobox', { name: 'Small resource type' })).toHaveClass('rounded-md') + await expect.element(smallScreen.getByRole('combobox', { name: 'Small resource type' })).toHaveClass('system-xs-regular') + + const largeScreen = await renderSelectLikeCombobox({ + children: ( + + + + ), + }) + + await expect.element(largeScreen.getByRole('combobox', { name: 'Large resource type' })).toHaveClass('rounded-[10px]') + await expect.element(largeScreen.getByRole('combobox', { name: 'Large resource type' })).toHaveClass('system-md-regular') + }) + + it('should render default trigger icon and support hiding it', async () => { + const withIcon = await renderSelectLikeCombobox() + + expect(withIcon.getByTestId('trigger').element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true') + + const withoutIcon = await renderSelectLikeCombobox({ + children: ( + + + + ), + }) + + expect(withoutIcon.getByRole('combobox', { name: 'Resource type without icon' }).element().querySelector('.i-ri-arrow-down-s-line')).not.toBeInTheDocument() + }) + }) + + describe('Input group and controls', () => { + it('should apply medium input group and input classes by default', async () => { + const screen = await renderInputCombobox() + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-lg') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('px-3') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('system-sm-regular') + }) + + it('should apply large input group and input classes when large size is provided', async () => { + const screen = await renderInputCombobox({ + children: ( + + + + ), + }) + + await expect.element(screen.getByTestId('input-group')).toHaveClass('rounded-[10px]') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('px-4') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('system-md-regular') + }) + + it('should set input defaults and forward passthrough props', async () => { + const screen = await renderInputCombobox({ + children: ( + + + + ), + }) + + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('autocomplete', 'off') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('type', 'text') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveAttribute('placeholder', 'Find a resource') + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toBeRequired() + await expect.element(screen.getByRole('combobox', { name: 'Search resources' })).toHaveClass('custom-input') + }) + + it('should provide fallback aria labels and decorative icons for input controls', async () => { + const screen = await renderInputCombobox() + + await expect.element(screen.getByRole('button', { name: 'Clear combobox' })).toHaveAttribute('type', 'button') + await expect.element(screen.getByRole('button', { name: 'Open combobox options' })).toHaveAttribute('type', 'button') + expect(screen.getByRole('button', { name: 'Clear combobox' }).element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true') + expect(screen.getByRole('button', { name: 'Open combobox options' }).element().querySelector('.i-ri-arrow-down-s-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should rely on aria-labelledby when provided instead of injecting fallback labels', async () => { + const screen = await renderInputCombobox({ + children: ( + <> + Clear from label + Trigger from label + + + + + + + ), + }) + + await expect.element(screen.getByRole('button', { name: 'Clear from label' })).not.toHaveAttribute('aria-label') + await expect.element(screen.getByRole('button', { name: 'Trigger from label' })).not.toHaveAttribute('aria-label') + }) + }) + + describe('Content and options', () => { + it('should use default overlay placement and Dify popup classes', async () => { + const screen = await renderSelectLikeCombobox({ open: true }) + + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'bottom') + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-align', 'start') + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveClass('z-1002') + await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('rounded-xl') + await expect.element(screen.getByRole('dialog', { name: 'combobox popup' })).toHaveClass('w-(--anchor-width)') + await expect.element(screen.getByRole('listbox', { name: 'combobox list' })).toHaveClass('scroll-py-1') + }) + + it('should apply custom placement side and passthrough popup props', async () => { + const onPopupClick = vi.fn() + const screen = await renderWithSafeViewport( + + + + + + + + Workflow + + + + , + ) + + asHTMLElement(screen.getByRole('dialog', { name: 'combobox popup' }).element()).click() + + await expect.element(screen.getByRole('group', { name: 'combobox positioner' })).toHaveAttribute('data-side', 'top') + expect(onPopupClick).toHaveBeenCalledTimes(1) + }) + + it('should render item text indicator status and empty wrappers with design classes', async () => { + const screen = await renderSelectLikeCombobox({ open: true }) + + await expect.element(screen.getByTestId('list').getByText('Workflow')).toHaveClass('system-sm-medium') + await expect.element(screen.getByTestId('status')).toHaveClass('text-text-tertiary') + await expect.element(screen.getByTestId('empty')).toHaveClass('system-sm-regular') + expect(screen.getByTestId('list').getByText('Workflow').element().parentElement?.querySelector('.i-ri-check-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should forward custom classes to group label separator item text and indicator', async () => { + const screen = await renderWithSafeViewport( + + + + + + + + Resources + + + Workflow + + + + + + , + ) + + await expect.element(screen.getByText('Resources')).toHaveClass('custom-label') + await expect.element(screen.getByTestId('separator')).toHaveClass('custom-separator') + await expect.element(screen.getByRole('option', { name: 'Workflow' })).toHaveClass('custom-item') + await expect.element(screen.getByTestId('custom-list').getByText('Workflow')).toHaveClass('custom-text') + await expect.element(screen.getByTestId('indicator')).toHaveClass('custom-indicator') + }) + }) + + describe('Multiple selection chips', () => { + it('should render chip wrappers and default remove button label', async () => { + const screen = await renderWithSafeViewport( + + + + {(selectedValue: string[]) => ( + + {selectedValue.map(item => ( + + {item} + + + ))} + + )} + + + + , + ) + + await expect.element(screen.getByTestId('chips')).toHaveClass('custom-chips') + await expect.element(screen.getByText('maya').element().parentElement!).toHaveClass('custom-chip') + await expect.element(screen.getByRole('button', { name: 'Remove selected item' })).toHaveAttribute('type', 'button') + expect(screen.getByTestId('remove-chip').element().querySelector('.i-ri-close-line')).toHaveAttribute('aria-hidden', 'true') + }) + + it('should preserve chip remove aria-labelledby over fallback label', async () => { + const screen = await renderWithSafeViewport( + + + + {(selectedValue: string[]) => ( + + {selectedValue.map(item => ( + + Remove Maya + + + ))} + + )} + + + + , + ) + + await expect.element(screen.getByRole('button', { name: 'Remove Maya' })).not.toHaveAttribute('aria-label') + }) + }) +}) diff --git a/packages/dify-ui/src/combobox/index.stories.tsx b/packages/dify-ui/src/combobox/index.stories.tsx new file mode 100644 index 0000000000..f2b5f4d4c6 --- /dev/null +++ b/packages/dify-ui/src/combobox/index.stories.tsx @@ -0,0 +1,618 @@ +import type { Meta, StoryObj } from '@storybook/react-vite' +import type { Virtualizer } from '@tanstack/react-virtual' +import type { RefObject } from 'react' +import { useVirtualizer } from '@tanstack/react-virtual' +import { useEffect, useRef, useState } from 'react' +import { + Combobox, + ComboboxChip, + ComboboxChipRemove, + ComboboxChips, + ComboboxClear, + ComboboxCollection, + ComboboxContent, + ComboboxEmpty, + ComboboxGroup, + ComboboxGroupLabel, + ComboboxInput, + ComboboxInputGroup, + ComboboxInputTrigger, + ComboboxItem, + ComboboxItemIndicator, + ComboboxItemText, + ComboboxLabel, + ComboboxList, + ComboboxSeparator, + ComboboxStatus, + ComboboxTrigger, + ComboboxValue, + useComboboxFilteredItems, +} from '.' +import { cn } from '../cn' + +type Option = { + value: string + label: string + meta?: string + icon?: string + disabled?: boolean +} + +type OptionGroup = { + label: string + items: Option[] +} + +const fieldWidth = 'w-80' +const wideFieldWidth = 'w-[520px]' +const nativeFieldLabelClassName = 'mb-1 block text-text-secondary system-sm-medium' + +type StoryVirtualizer = Virtualizer + +const scrollHighlightedVirtualItem = ( + item: unknown, + { + reason, + index, + }: { + reason: 'keyboard' | 'pointer' | 'none' + index: number + }, + virtualizer: StoryVirtualizer | null, +) => { + if (!item || !virtualizer) + return + + const isStart = index === 0 + const isEnd = index === virtualizer.options.count - 1 + const shouldScroll = reason === 'none' || (reason === 'keyboard' && (isStart || isEnd)) + + if (shouldScroll) { + queueMicrotask(() => { + virtualizer.scrollToIndex(index, { align: isEnd ? 'start' : 'end' }) + }) + } +} + +const providerOptions: Option[] = [ + { value: 'openai', label: 'OpenAI', meta: 'GPT-5, GPT-4.1', icon: 'i-ri-openai-fill' }, + { value: 'anthropic', label: 'Anthropic', meta: 'Claude Opus, Sonnet', icon: 'i-ri-sparkling-2-line' }, + { value: 'google', label: 'Google', meta: 'Gemini 2.5', icon: 'i-ri-google-fill' }, + { value: 'azure-openai', label: 'Azure OpenAI', meta: 'Enterprise workspace', icon: 'i-ri-microsoft-fill' }, + { value: 'localai', label: 'LocalAI', meta: 'Self-hosted endpoint', icon: 'i-ri-server-line', disabled: true }, +] + +const dataSourceOptions: Option[] = [ + { value: 'knowledge-base', label: 'Knowledge Base', meta: 'Vector index', icon: 'i-ri-database-2-line' }, + { value: 'notion', label: 'Notion', meta: 'Synced pages', icon: 'i-ri-notion-fill' }, + { value: 'website', label: 'Website crawler', meta: 'Public URLs', icon: 'i-ri-global-line' }, + { value: 's3', label: 'S3 bucket', meta: 'Private files', icon: 'i-ri-cloud-line' }, + { value: 'slack', label: 'Slack', meta: 'Channel history', icon: 'i-ri-slack-fill' }, +] + +const reviewerOptions: Option[] = [ + { value: 'maya', label: 'Maya Chen', meta: 'Product owner' }, + { value: 'liam', label: 'Liam Brooks', meta: 'Prompt engineer' }, + { value: 'nora', label: 'Nora Park', meta: 'Data steward' }, + { value: 'owen', label: 'Owen Reed', meta: 'Security reviewer' }, + { value: 'yuki', label: 'Yuki Tanaka', meta: 'ML engineer' }, +] + +const toolGroups: OptionGroup[] = [ + { + label: 'Retrieval', + items: [ + { value: 'dataset-search', label: 'Dataset search', meta: 'Search workspace knowledge', icon: 'i-ri-search-eye-line' }, + { value: 'web-scraper', label: 'Web scraper', meta: 'Fetch public pages', icon: 'i-ri-global-line' }, + ], + }, + { + label: 'Actions', + items: [ + { value: 'http-request', label: 'HTTP request', meta: 'Call external APIs', icon: 'i-ri-terminal-box-line' }, + { value: 'code-runner', label: 'Code runner', meta: 'Execute sandboxed scripts', icon: 'i-ri-code-s-slash-line' }, + ], + }, + { + label: 'Operations', + items: [ + { value: 'human-review', label: 'Human review', meta: 'Assign approval task', icon: 'i-ri-user-voice-line' }, + { value: 'audit-log', label: 'Audit log', meta: 'Record workflow events', icon: 'i-ri-file-list-3-line' }, + ], + }, +] + +const tagOptions: Option[] = [ + { value: 'rag', label: 'RAG' }, + { value: 'agent', label: 'Agent' }, + { value: 'production', label: 'Production' }, + { value: 'evaluation', label: 'Evaluation' }, + { value: 'finance', label: 'Finance' }, + { value: 'support', label: 'Support' }, +] + +const directoryOptions: Option[] = [ + { value: 'maya-chen', label: 'Maya Chen', meta: 'Product owner · maya@example.com', icon: 'i-ri-user-3-line' }, + { value: 'liam-brooks', label: 'Liam Brooks', meta: 'Prompt engineer · liam@example.com', icon: 'i-ri-user-3-line' }, + { value: 'nora-park', label: 'Nora Park', meta: 'Data steward · nora@example.com', icon: 'i-ri-user-3-line' }, + { value: 'owen-reed', label: 'Owen Reed', meta: 'Security reviewer · owen@example.com', icon: 'i-ri-shield-user-line' }, + { value: 'yuki-tanaka', label: 'Yuki Tanaka', meta: 'ML engineer · yuki@example.com', icon: 'i-ri-user-3-line' }, + { value: 'ava-martin', label: 'Ava Martin', meta: 'Support lead · ava@example.com', icon: 'i-ri-customer-service-2-line' }, +] + +const emptyOptions: Option[] = [ + { value: 'billing', label: 'Billing connector' }, + { value: 'zendesk', label: 'Zendesk' }, + { value: 'github', label: 'GitHub issues' }, +] + +const modelCatalogOptions: Option[] = Array.from({ length: 1000 }, (_, index) => { + const provider = ['OpenAI', 'Anthropic', 'Google', 'Mistral', 'DeepSeek'][index % 5]! + const family = ['chat', 'reasoning', 'vision', 'embedding'][index % 4]! + const number = new Intl.NumberFormat('en-US', { + minimumIntegerDigits: 4, + }).format(index + 1) + + return { + value: `model-${index + 1}`, + label: `${provider} ${family} ${number}`, + meta: `${provider} provider · ${family}`, + icon: family === 'embedding' + ? 'i-ri-vector-triangle' + : family === 'vision' + ? 'i-ri-image-circle-line' + : family === 'reasoning' + ? 'i-ri-brain-line' + : 'i-ri-chat-1-line', + } +}) + +const sizeOptions: Option[] = providerOptions.slice(0, 3) +const defaultProvider = providerOptions[0]! +const disabledProvider = providerOptions[1]! +const defaultDataSource = dataSourceOptions[0]! +const defaultPopupDataSource = dataSourceOptions[1]! +const readOnlyDataSource = dataSourceOptions[2]! +const defaultTool = toolGroups[0]!.items[0]! +const defaultReviewers = [reviewerOptions[0]!, reviewerOptions[2]!] +const defaultTag = tagOptions[2]! + +const renderOptionItem = (option: Option, index?: number) => ( + + + {option.icon && } + + {option.label} + {option.meta && {option.meta}} + + + + +) + +const renderSimpleOptionItem = (option: Option, index?: number) => ( + + {option.label} + + +) + +const PopupSearchInput = ({ + label, + placeholder, +}: { + label: string + placeholder: string +}) => ( + + + + + +) + +const GroupedToolList = () => { + const groups = useComboboxFilteredItems() + + return ( + + {groups.map((group, groupIndex) => ( + + {groupIndex > 0 && } + {group.label} + + {(option: Option) => renderOptionItem(option)} + + + ))} + + ) +} + +const VirtualizedModelList = ({ + virtualizerRef, +}: { + virtualizerRef: RefObject +}) => { + const scrollRef = useRef(null) + const filteredItems = useComboboxFilteredItems