From 8fd616d27f014903588aaf7647695becd8709e0d Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 7 May 2026 12:46:23 +0900 Subject: [PATCH] refactor: add type to test (#30873) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/extensions/ext_session_factory.py | 4 +- .../conftest.py | 6 +- .../controllers/console/app/test_app_apis.py | 49 ++++++----- .../console/app/test_app_import_api.py | 21 ++--- .../console/auth/test_email_register.py | 9 +- .../console/auth/test_forgot_password.py | 9 +- .../controllers/console/auth/test_oauth.py | 9 +- .../console/auth/test_password_reset.py | 25 +++--- .../rag_pipeline/test_rag_pipeline.py | 9 +- .../test_rag_pipeline_datasets.py | 17 ++-- .../rag_pipeline/test_rag_pipeline_import.py | 9 +- .../test_rag_pipeline_workflow.py | 83 ++++++++--------- .../console/datasets/test_data_source.py | 11 +-- .../console/explore/test_conversation.py | 11 +-- .../console/workspace/test_tool_provider.py | 83 ++++++++--------- .../workspace/test_trigger_providers.py | 15 ++-- .../service_api/dataset/test_dataset.py | 3 +- .../controllers/web/test_conversation.py | 37 ++++---- .../web/test_web_forgot_password.py | 17 ++-- .../controllers/web/test_wraps.py | 3 +- .../layers/test_pause_state_persist_layer.py | 16 ++-- .../rag/pipeline/test_queue_integration.py | 25 +++--- .../test_dataset_retrieval_integration.py | 25 +++--- .../auth/test_api_key_auth_service.py | 47 +++++++--- .../services/auth/test_auth_integration.py | 3 +- .../services/document_service_status.py | 71 ++++++++++----- .../enterprise/test_account_deletion_sync.py | 9 +- .../recommend_app/test_database_retrieval.py | 18 ++-- .../test_advanced_prompt_template_service.py | 77 ++++++++++------ .../services/test_app_dsl_service.py | 88 +++++++++++-------- .../services/test_attachment_service.py | 8 +- .../services/test_billing_service.py | 21 ++--- .../services/test_conversation_service.py | 67 ++++++++------ .../test_conversation_service_variables.py | 31 +++---- .../test_conversation_variable_updater.py | 15 ++-- .../services/test_credit_pool_service.py | 21 ++--- .../test_dataset_permission_service.py | 41 +++++---- .../services/test_dataset_service.py | 10 +-- .../test_dataset_service_create_dataset.py | 3 +- .../test_dataset_service_delete_dataset.py | 12 +-- .../test_dataset_service_permissions.py | 5 +- .../test_delete_archived_workflow_run.py | 19 ++-- .../services/test_end_user_service.py | 45 +++++----- .../services/test_feature_service.py | 87 +++++++++++------- .../test_human_input_delivery_test_service.py | 7 +- .../services/test_metadata_partial_update.py | 9 +- .../services/test_oauth_server_service.py | 7 +- .../test_restore_archived_workflow_run.py | 5 +- .../services/test_webhook_service.py | 13 +-- .../test_webhook_service_relationships.py | 27 +++--- .../services/test_workflow_app_service.py | 8 +- .../test_workflow_draft_variable_service.py | 6 +- .../services/test_workflow_service.py | 12 +-- .../workflow/test_workflow_deletion.py | 8 +- ...kflow_node_execution_service_repository.py | 20 ++--- .../tasks/test_clean_notion_document_task.py | 24 ++--- .../test_create_segment_to_index_task.py | 69 +++++++++------ .../tasks/test_dataset_indexing_task.py | 63 ++++++++----- .../test_deal_dataset_vector_index_task.py | 33 +++---- .../test_delete_segment_from_index_task.py | 51 +++++++---- .../test_disable_segments_from_index_task.py | 25 +++--- .../tasks/test_document_indexing_sync_task.py | 29 +++--- .../tasks/test_document_indexing_task.py | 49 ++++++----- .../test_document_indexing_update_task.py | 17 ++-- .../test_duplicate_document_indexing_task.py | 51 ++++++----- .../tasks/test_mail_change_mail_task.py | 17 ++-- .../tasks/test_mail_email_code_login_task.py | 27 +++--- .../test_mail_human_input_delivery_task.py | 5 +- .../tasks/test_mail_inner_task.py | 19 ++-- .../tasks/test_mail_invite_member_task.py | 27 +++--- .../tasks/test_mail_owner_transfer_task.py | 23 ++--- .../tasks/test_mail_register_task.py | 15 ++-- .../tasks/test_rag_pipeline_run_tasks.py | 13 +-- .../test_remove_app_and_related_data_task.py | 7 +- .../test_opendal_fs_default_root.py | 8 +- .../test_workflow_pause_integration.py | 2 +- .../code_executor/test_code_javascript.py | 10 ++- .../nodes/code_executor/test_code_jinja2.py | 10 ++- .../nodes/code_executor/test_code_python3.py | 10 ++- .../console/auth/test_account_activation.py | 4 +- .../console/auth/test_email_verification.py | 10 +-- .../console/auth/test_login_logout.py | 42 ++++----- .../console/auth/test_token_refresh.py | 4 +- .../console/billing/test_billing.py | 16 ++-- .../core/tools/utils/test_parser.py | 6 +- .../unit_tests/services/controller_api.py | 29 +++--- .../generated/api/console/apps/types.gen.ts | 4 +- .../generated/api/console/apps/zod.gen.ts | 4 +- .../api/console/datasets/types.gen.ts | 32 ++++++- .../generated/api/console/datasets/zod.gen.ts | 43 ++++++++- .../generated/api/service/types.gen.ts | 32 ++++++- .../generated/api/service/zod.gen.ts | 43 ++++++++- 92 files changed, 1307 insertions(+), 882 deletions(-) diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py index 0eb43d66f4..e19ccd11e5 100644 --- a/api/extensions/ext_session_factory.py +++ b/api/extensions/ext_session_factory.py @@ -1,7 +1,9 @@ +from flask import Flask + from core.db.session_factory import configure_session_factory from extensions.ext_database import db -def init_app(app): +def init_app(app: Flask): with app.app_context(): configure_session_factory(db.engine) diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 66a25e5daf..b4482674da 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask: @pytest.fixture -def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]: +def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]: """ Request context fixture for containerized Flask application. @@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, @pytest.fixture -def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]: +def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]: """ Test client fixture for containerized Flask application. @@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli @pytest.fixture -def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]: +def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]: """ Database session fixture for containerized testing. diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 18755ef012..bb737754a1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -69,7 +70,7 @@ def _unwrap(func): class TestCompletionEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_completion_create_payload(self): @@ -86,7 +87,7 @@ class TestCompletionEndpoints: ) assert payload.query == "hi" - def test_completion_api_success(self, app, monkeypatch): + def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -116,7 +117,7 @@ class TestCompletionEndpoints: assert resp == {"result": {"text": "ok"}} - def test_completion_api_conversation_not_exists(self, app, monkeypatch): + def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -142,7 +143,7 @@ class TestCompletionEndpoints: with pytest.raises(NotFound): method(app_model=MagicMock(id="app-1")) - def test_completion_api_provider_not_initialized(self, app, monkeypatch): + def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -166,7 +167,7 @@ class TestCompletionEndpoints: with pytest.raises(completion_module.ProviderNotInitializeError): method(app_model=MagicMock(id="app-1")) - def test_completion_api_quota_exceeded(self, app, monkeypatch): + def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -193,10 +194,10 @@ class TestCompletionEndpoints: class TestAppEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppApi() method = _unwrap(api.put) payload = { @@ -234,7 +235,7 @@ class TestAppEndpoints: } ) - def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch): + def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppIconApi() method = _unwrap(api.post) payload = { @@ -266,7 +267,7 @@ class TestAppEndpoints: class TestOpsTraceEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_ops_trace_query_basic(self): @@ -277,7 +278,7 @@ class TestOpsTraceEndpoints: payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"}) assert payload.tracing_config["api_key"] == "k" - def test_trace_app_config_get_empty(self, app, monkeypatch): + def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.get) @@ -292,7 +293,7 @@ class TestOpsTraceEndpoints: assert result == {"has_not_configured": True} - def test_trace_app_config_post_invalid(self, app, monkeypatch): + def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.post) @@ -309,7 +310,7 @@ class TestOpsTraceEndpoints: with pytest.raises(BadRequest): method(app_id="app-1") - def test_trace_app_config_delete_not_found(self, app, monkeypatch): + def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.delete) @@ -326,7 +327,7 @@ class TestOpsTraceEndpoints: class TestSiteEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_site_response_structure(self): @@ -337,7 +338,7 @@ class TestSiteEndpoints: payload = AppSiteUpdatePayload(default_language="en-US") assert payload.default_language == "en-US" - def test_app_site_update_post(self, app, monkeypatch): + def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSite() method = _unwrap(api.post) @@ -375,7 +376,7 @@ class TestSiteEndpoints: assert isinstance(result, dict) assert result["title"] == "My Site" - def test_app_site_access_token_reset(self, app, monkeypatch): + def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) @@ -427,7 +428,7 @@ class TestWorkflowEndpoints: class TestWorkflowAppLogEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_app_log_query(self): @@ -438,7 +439,7 @@ class TestWorkflowAppLogEndpoints: query = WorkflowAppLogQuery(detail="true") assert query.detail is True - def test_workflow_app_log_api_get(self, app, monkeypatch): + def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_app_log_module.WorkflowAppLogApi() method = _unwrap(api.get) @@ -477,14 +478,14 @@ class TestWorkflowAppLogEndpoints: class TestWorkflowDraftVariableEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_variable_creation(self): payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test") assert payload.name == "var1" - def test_workflow_variable_collection_get(self, app, monkeypatch): + def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_draft_variable_module.WorkflowVariableCollectionApi() method = _unwrap(api.get) @@ -529,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints: class TestWorkflowStatisticEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_statistic_time_range(self): @@ -541,7 +542,7 @@ class TestWorkflowStatisticEndpoints: assert query.start is None assert query.end is None - def test_workflow_daily_runs_statistic(self, app, monkeypatch): + def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -567,7 +568,7 @@ class TestWorkflowStatisticEndpoints: assert response.get_json() == {"data": [{"date": "2024-01-01"}]} - def test_workflow_daily_terminals_statistic(self, app, monkeypatch): + def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -598,7 +599,7 @@ class TestWorkflowStatisticEndpoints: class TestWorkflowTriggerEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_webhook_trigger_payload(self): @@ -608,7 +609,7 @@ class TestWorkflowTriggerEndpoints: enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True) assert enable_payload.enable_trigger is True - def test_webhook_trigger_api_get(self, app, monkeypatch): + def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_trigger_module.WebhookTriggerApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index 25d19cf35a..bcb6e41ef7 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from controllers.console.app import app_import as app_import_module from services.app_dsl_service import ImportStatus @@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: class TestAppImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -57,7 +58,7 @@ class TestAppImportApi: assert status == 400 assert response["status"] == ImportStatus.FAILED - def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -75,7 +76,7 @@ class TestAppImportApi: assert status == 202 assert response["status"] == ImportStatus.PENDING - def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -96,7 +97,7 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED - def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -121,7 +122,7 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED - def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -149,10 +150,10 @@ class TestAppImportApi: class TestAppImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportConfirmApi() method = _unwrap(api.post) @@ -172,10 +173,10 @@ class TestAppImportConfirmApi: class TestAppImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportCheckDependenciesApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 320da85b60..1fcce9ca44 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - app, + app: Flask, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False @@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi: mock_revoke, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} @@ -120,7 +121,7 @@ class TestEmailRegisterResetApi: mock_create_account, mock_login, mock_reset_login_rate, - app, + app: Flask, ): mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} mock_create_account.return_value = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index d2703ed5cc..014c1588fe 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} @@ -123,7 +124,7 @@ class TestForgotPasswordResetApi: mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 1eabb45422..01d88d247c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.mark.parametrize( @@ -65,7 +66,7 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -130,7 +131,7 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -394,7 +395,7 @@ class TestOAuthCallback: class TestAccountGeneration: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 50249bcd74..8d6b25b5b3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email.assert_called_once() @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask): """ Test password reset email blocked by IP rate limit. @@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi: mock_reset_rate_limit.assert_called_once_with("user@example.com") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask): """ Test code verification blocked by rate limit. @@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with invalid token. @@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with mismatched email. @@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with incorrect code. @@ -321,7 +322,7 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -375,7 +376,7 @@ class TestForgotPasswordResetApi: mock_revoke_token.assert_called_once_with("valid_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, app): + def test_reset_password_mismatch(self, mock_get_data, app: Flask): """ Test password reset with mismatched passwords. @@ -397,7 +398,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, app): + def test_reset_password_invalid_token(self, mock_get_data, app: Flask): """ Test password reset with invalid token. @@ -418,7 +419,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, app): + def test_reset_password_wrong_phase(self, mock_get_data, app: Flask): """ Test password reset with token not in reset phase. @@ -442,7 +443,7 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask): """ Test password reset for non-existent account. diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index d5ae95dfb7..2752e6b34f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from controllers.console import console_ns @@ -26,7 +27,7 @@ def unwrap(func): class TestPipelineTemplateListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -50,7 +51,7 @@ class TestPipelineTemplateListApi: class TestPipelineTemplateDetailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -115,7 +116,7 @@ class TestPipelineTemplateDetailApi: class TestCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_patch_success(self, app): @@ -193,7 +194,7 @@ class TestCustomizedPipelineTemplateApi: class TestPublishCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_post_success(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index 64e3de2ca3..7624c1150f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden import services @@ -24,13 +25,13 @@ def unwrap(func): class TestCreateRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _valid_payload(self): return {"yaml_content": "name: test"} - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi: assert status == 201 assert response == import_info - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(Forbidden): method(api) - def test_post_dataset_name_duplicate(self, app): + def test_post_dataset_name_duplicate(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi: class TestCreateEmptyRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) @@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi: assert status == 201 assert response == {"id": "ds-1"} - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index cb67892878..f238ca13ee 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( @@ -25,7 +26,7 @@ def unwrap(func): class TestRagPipelineImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _payload(self, mode="create"): @@ -128,7 +129,7 @@ class TestRagPipelineImportApi: class TestRagPipelineImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_confirm_success(self, app): @@ -190,7 +191,7 @@ class TestRagPipelineImportConfirmApi: class TestRagPipelineImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -219,7 +220,7 @@ class TestRagPipelineImportCheckDependenciesApi: class TestRagPipelineExportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_with_include_secret(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index c1f3122c2b..1fdb3057b8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound @@ -45,10 +46,10 @@ def unwrap(func): class TestDraftWorkflowApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_draft_success(self, app): + def test_get_draft_success(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -68,7 +69,7 @@ class TestDraftWorkflowApi: result = method(api, pipeline) assert result == workflow - def test_get_draft_not_exist(self, app): + def test_get_draft_not_exist(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -86,7 +87,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotExist): method(api, pipeline) - def test_sync_hash_not_match(self, app): + def test_sync_hash_not_match(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -111,7 +112,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotSync): method(api, pipeline) - def test_sync_invalid_text_plain(self, app): + def test_sync_invalid_text_plain(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -128,7 +129,7 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 - def test_restore_published_workflow_to_draft_success(self, app): + def test_restore_published_workflow_to_draft_success(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -155,7 +156,7 @@ class TestDraftWorkflowApi: assert result["result"] == "success" assert result["hash"] == "restored-hash" - def test_restore_published_workflow_to_draft_not_found(self, app): + def test_restore_published_workflow_to_draft_not_found(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -179,7 +180,7 @@ class TestDraftWorkflowApi: with pytest.raises(NotFound): method(api, pipeline, "published-workflow") - def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -211,10 +212,10 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_iteration_node_success(self, app): + def test_iteration_node_success(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -240,7 +241,7 @@ class TestDraftRunNodes: result = method(api, pipeline, "node") assert result == {"ok": True} - def test_iteration_node_conversation_not_exists(self, app): + def test_iteration_node_conversation_not_exists(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -262,7 +263,7 @@ class TestDraftRunNodes: with pytest.raises(NotFound): method(api, pipeline, "node") - def test_loop_node_success(self, app): + def test_loop_node_success(self, app: Flask): api = RagPipelineDraftRunLoopNodeApi() method = unwrap(api.post) @@ -290,10 +291,10 @@ class TestDraftRunNodes: class TestPipelineRunApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_draft_run_success(self, app): + def test_draft_run_success(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -325,7 +326,7 @@ class TestPipelineRunApis: ): assert method(api, pipeline) == {"ok": True} - def test_draft_run_rate_limit(self, app): + def test_draft_run_rate_limit(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -356,10 +357,10 @@ class TestPipelineRunApis: class TestDraftNodeRun: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_execution_not_found(self, app): + def test_execution_not_found(self, app: Flask): api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) @@ -387,7 +388,7 @@ class TestDraftNodeRun: class TestPublishedPipelineApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_publish_success(self, app, db_session_with_containers: Session): @@ -436,10 +437,10 @@ class TestPublishedPipelineApis: class TestMiscApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_task_stop(self, app): + def test_task_stop(self, app: Flask): api = RagPipelineTaskStopApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestMiscApis: stop_mock.assert_called_once() assert result["result"] == "success" - def test_transform_forbidden(self, app): + def test_transform_forbidden(self, app: Flask): api = RagPipelineTransformApi() method = unwrap(api.post) @@ -476,7 +477,7 @@ class TestMiscApis: with pytest.raises(Forbidden): method(api, "ds1") - def test_recommended_plugins(self, app): + def test_recommended_plugins(self, app: Flask): api = RagPipelineRecommendedPluginApi() method = unwrap(api.get) @@ -496,10 +497,10 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_published_run_success(self, app): + def test_published_run_success(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi: result = method(api, pipeline) assert result == {"ok": True} - def test_published_run_rate_limit(self, app): + def test_published_run_rate_limit(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi: class TestDefaultBlockConfigApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_block_config_success(self, app): + def test_get_block_config_success(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi: result = method(api, pipeline, "llm") assert result == {"k": "v"} - def test_get_block_config_invalid_json(self, app): + def test_get_block_config_invalid_json(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_published_workflows_success(self, app): + def test_get_published_workflows_success(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi: assert result["items"] == [{"id": "w1"}] assert result["has_more"] is False - def test_get_published_workflows_forbidden(self, app): + def test_get_published_workflows_forbidden(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi: class TestRagPipelineByIdApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -682,7 +683,7 @@ class TestRagPipelineByIdApi: assert result == workflow - def test_patch_no_fields(self, app): + def test_patch_no_fields(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -700,7 +701,7 @@ class TestRagPipelineByIdApi: result, status = method(api, pipeline, "w1") assert status == 400 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -720,7 +721,7 @@ class TestRagPipelineByIdApi: workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) - def test_delete_active_workflow_rejected(self, app): + def test_delete_active_workflow_rejected(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -733,10 +734,10 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_last_run_success(self, app): + def test_last_run_success(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi: result = method(api, pipeline, "node1") assert result == node_exec - def test_last_run_not_found(self, app): + def test_last_run_not_found(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_set_datasource_variables_success(self, app): + def test_set_datasource_variables_success(self, app: Flask): api = RagPipelineDatasourceVariableApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 1c4c6a899f..50ad92afa1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console.datasets import data_source @@ -51,7 +52,7 @@ def mock_engine(): class TestDataSourceApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): @@ -188,7 +189,7 @@ class TestDataSourceApi: class TestDataSourceNotionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_credential_not_found(self, app, patch_tenant): @@ -323,7 +324,7 @@ class TestDataSourceNotionListApi: class TestDataSourceNotionApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_preview_success(self, app, patch_tenant): @@ -381,7 +382,7 @@ class TestDataSourceNotionApi: class TestDataSourceNotionDatasetSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): @@ -424,7 +425,7 @@ class TestDataSourceNotionDatasetSyncApi: class TestDataSourceNotionDocumentSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 83492048ef..0b53ca5585 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.conversation as conversation_module @@ -53,7 +54,7 @@ def user(): class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, chat_app, user): @@ -108,7 +109,7 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_delete_success(self, app, chat_app, user): @@ -156,7 +157,7 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_rename_success(self, app, chat_app, user): @@ -197,7 +198,7 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_pin_success(self, app, chat_app, user): @@ -219,7 +220,7 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_unpin_success(self, app, chat_app, user): diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index f2e7104b18..d944613886 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -6,6 +6,7 @@ import json from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -60,7 +61,7 @@ def _mock_user_tenant(): @pytest.fixture -def client(flask_app_with_containers): +def client(flask_app_with_containers: Flask): return flask_app_with_containers.test_client() @@ -147,10 +148,10 @@ class TestUtils: class TestToolProviderListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ToolProviderListApi() method = unwrap(api.get) @@ -170,10 +171,10 @@ class TestToolProviderListApi: class TestBuiltinProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolBuiltinProviderListToolsApi() method = unwrap(api.get) @@ -190,7 +191,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == [{"a": 1}] - def test_info(self, app): + def test_info(self, app: Flask): api = ToolBuiltinProviderInfoApi() method = unwrap(api.get) @@ -207,7 +208,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"x": 1} - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolBuiltinProviderDeleteApi() method = unwrap(api.post) @@ -224,7 +225,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["result"] == "success" - def test_add_invalid_type(self, app): + def test_add_invalid_type(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -238,7 +239,7 @@ class TestBuiltinProviderApis: with pytest.raises(ValueError): method(api, "provider") - def test_add_success(self, app): + def test_add_success(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -257,7 +258,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["id"] == 1 - def test_update(self, app): + def test_update(self, app: Flask): api = ToolBuiltinProviderUpdateApi() method = unwrap(api.post) @@ -276,7 +277,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credentials(self, app): + def test_get_credentials(self, app: Flask): api = ToolBuiltinProviderGetCredentialsApi() method = unwrap(api.get) @@ -293,7 +294,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"k": "v"} - def test_icon(self, app): + def test_icon(self, app: Flask): api = ToolBuiltinProviderIconApi() method = unwrap(api.get) @@ -307,7 +308,7 @@ class TestBuiltinProviderApis: response = method(api, "provider") assert response.mimetype == "image/png" - def test_credentials_schema(self, app): + def test_credentials_schema(self, app: Flask): api = ToolBuiltinProviderCredentialsSchemaApi() method = unwrap(api.get) @@ -324,7 +325,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider", "oauth2") == {"schema": {}} - def test_set_default_credential(self, app): + def test_set_default_credential(self, app: Flask): api = ToolBuiltinProviderSetDefaultApi() method = unwrap(api.post) @@ -341,7 +342,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credential_info(self, app): + def test_get_credential_info(self, app: Flask): api = ToolBuiltinProviderGetCredentialInfoApi() method = unwrap(api.get) @@ -358,7 +359,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"info": "x"} - def test_get_oauth_client_schema(self, app): + def test_get_oauth_client_schema(self, app: Flask): api = ToolBuiltinProviderGetOauthClientSchemaApi() method = unwrap(api.get) @@ -378,10 +379,10 @@ class TestBuiltinProviderApis: class TestApiProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_add(self, app): + def test_add(self, app: Flask): api = ToolApiProviderAddApi() method = unwrap(api.post) @@ -406,7 +407,7 @@ class TestApiProviderApis: ): assert method(api)["id"] == 1 - def test_remote_schema(self, app): + def test_remote_schema(self, app: Flask): api = ToolApiProviderGetRemoteSchemaApi() method = unwrap(api.get) @@ -423,7 +424,7 @@ class TestApiProviderApis: ): assert method(api)["schema"] == "x" - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolApiProviderListToolsApi() method = unwrap(api.get) @@ -440,7 +441,7 @@ class TestApiProviderApis: ): assert method(api) == [{"tool": 1}] - def test_update(self, app): + def test_update(self, app: Flask): api = ToolApiProviderUpdateApi() method = unwrap(api.post) @@ -468,7 +469,7 @@ class TestApiProviderApis: ): assert method(api)["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolApiProviderDeleteApi() method = unwrap(api.post) @@ -485,7 +486,7 @@ class TestApiProviderApis: ): assert method(api)["result"] == "success" - def test_get(self, app): + def test_get(self, app: Flask): api = ToolApiProviderGetApi() method = unwrap(api.get) @@ -505,10 +506,10 @@ class TestApiProviderApis: class TestWorkflowApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create(self, app): + def test_create(self, app: Flask): api = ToolWorkflowProviderCreateApi() method = unwrap(api.post) @@ -534,7 +535,7 @@ class TestWorkflowApis: ): assert method(api)["id"] == 1 - def test_update_invalid(self, app): + def test_update_invalid(self, app: Flask): api = ToolWorkflowProviderUpdateApi() method = unwrap(api.post) @@ -560,7 +561,7 @@ class TestWorkflowApis: result = method(api) assert result["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolWorkflowProviderDeleteApi() method = unwrap(api.post) @@ -577,7 +578,7 @@ class TestWorkflowApis: ): assert method(api)["ok"] - def test_get_error(self, app): + def test_get_error(self, app: Flask): api = ToolWorkflowProviderGetApi() method = unwrap(api.get) @@ -594,10 +595,10 @@ class TestWorkflowApis: class TestLists: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_builtin_list(self, app): + def test_builtin_list(self, app: Flask): api = ToolBuiltinListApi() method = unwrap(api.get) @@ -617,7 +618,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_api_list(self, app): + def test_api_list(self, app: Flask): api = ToolApiListApi() method = unwrap(api.get) @@ -637,7 +638,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_workflow_list(self, app): + def test_workflow_list(self, app: Flask): api = ToolWorkflowListApi() method = unwrap(api.get) @@ -660,10 +661,10 @@ class TestLists: class TestLabels: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_labels(self, app): + def test_labels(self, app: Flask): api = ToolLabelsApi() method = unwrap(api.get) @@ -679,10 +680,10 @@ class TestLabels: class TestOAuth: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_no_client(self, app): + def test_oauth_no_client(self, app: Flask): api = ToolPluginOAuthApi() method = unwrap(api.get) @@ -700,7 +701,7 @@ class TestOAuth: with pytest.raises(Forbidden): method(api, "provider") - def test_oauth_callback_no_cookie(self, app): + def test_oauth_callback_no_cookie(self, app: Flask): api = ToolOAuthCallback() method = unwrap(api.get) @@ -711,10 +712,10 @@ class TestOAuth: class TestOAuthCustomClient: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_save_custom_client(self, app): + def test_save_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.post) @@ -731,7 +732,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider")["ok"] - def test_get_custom_client(self, app): + def test_get_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.get) @@ -748,7 +749,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider") == {"client_id": "x"} - def test_delete_custom_client(self, app): + def test_delete_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.delete) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index ca8195af53..6efdaf2943 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden from controllers.console.workspace.trigger_providers import ( @@ -45,7 +46,7 @@ def mock_user(): class TestTriggerProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_icon_success(self, app): @@ -93,7 +94,7 @@ class TestTriggerProviderApis: class TestTriggerSubscriptionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_list_success(self, app): @@ -128,7 +129,7 @@ class TestTriggerSubscriptionListApi: class TestTriggerSubscriptionBuilderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_create_builder(self, app): @@ -236,7 +237,7 @@ class TestTriggerSubscriptionBuilderApis: class TestTriggerSubscriptionCrud: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_update_rename_only(self, app): @@ -342,7 +343,7 @@ class TestTriggerSubscriptionCrud: class TestTriggerOAuthApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_oauth_authorize_success(self, app): @@ -480,7 +481,7 @@ class TestTriggerOAuthApis: class TestTriggerOAuthClientManageApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_client(self, app): @@ -556,7 +557,7 @@ class TestTriggerOAuthClientManageApi: class TestTriggerSubscriptionVerifyApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_verify_success(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 437b199ec2..5791d2f6e2 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -18,6 +18,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound @@ -246,7 +247,7 @@ def _unwrap(method): @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): # Uses the full containerised app so that Flask config, extensions, and # blueprint registrations match production. Most tests mock the service # layer to isolate controller logic; a few (e.g. test_list_tags_from_db) diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py index e1e6741014..c34da27ebe 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.web.conversation import ( @@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace: class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context("/conversations"): with pytest.raises(NotChatAppError): ConversationListApi().get(_completion_app(), _end_user()) @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id") - def test_happy_path(self, mock_paginate: MagicMock, app) -> None: + def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None: conv_id = str(uuid4()) conv = SimpleNamespace( id=conv_id, @@ -65,16 +66,16 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}"): with pytest.raises(NotChatAppError): ConversationApi().delete(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.ConversationService.delete") - def test_delete_success(self, mock_delete: MagicMock, app) -> None: + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}"): result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) @@ -83,7 +84,7 @@ class TestConversationApi: assert result["result"] == "success" @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) - def test_delete_not_found(self, mock_delete: MagicMock, app) -> None: + def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}"): with pytest.raises(NotFound, match="Conversation Not Exists"): @@ -92,17 +93,17 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}): with pytest.raises(NotChatAppError): ConversationRenameApi().post(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.ConversationService.rename") @patch("controllers.web.conversation.web_ns") - def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None: + def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: c_id = uuid4() mock_ns.payload = {"name": "New Name", "auto_generate": False} conv = SimpleNamespace( @@ -126,7 +127,7 @@ class TestConversationRenameApi: side_effect=ConversationNotExistsError(), ) @patch("controllers.web.conversation.web_ns") - def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None: + def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: c_id = uuid4() mock_ns.payload = {"name": "X", "auto_generate": False} @@ -137,16 +138,16 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"): with pytest.raises(NotChatAppError): ConversationPinApi().patch(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.WebConversationService.pin") - def test_pin_success(self, mock_pin: MagicMock, app) -> None: + def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id) @@ -154,7 +155,7 @@ class TestConversationPinApi: assert result["result"] == "success" @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError()) - def test_pin_not_found(self, mock_pin: MagicMock, app) -> None: + def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): with pytest.raises(NotFound): @@ -163,16 +164,16 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"): with pytest.raises(NotChatAppError): ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.WebConversationService.unpin") - def test_unpin_success(self, mock_unpin: MagicMock, app) -> None: + def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"): result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id) diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 635cfee2da..2c6a990240 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.web.forgot_password import ( ForgotPasswordCheckApi, @@ -29,7 +30,7 @@ def _patch_wraps(): class TestForgotPasswordSendEmailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") @@ -42,7 +43,7 @@ class TestForgotPasswordSendEmailApi: mock_rate_limit, mock_get_account, mock_send_mail, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -64,7 +65,7 @@ class TestForgotPasswordSendEmailApi: class TestForgotPasswordCheckApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") @@ -81,7 +82,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} @@ -117,7 +118,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"} @@ -142,7 +143,7 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @@ -157,7 +158,7 @@ class TestForgotPasswordResetApi: mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() @@ -194,7 +195,7 @@ class TestForgotPasswordResetApi: mock_db, mock_token_bytes, mock_hash_password, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py index 19833cc772..de9e691434 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound, Unauthorized @@ -182,7 +183,7 @@ class TestValidateUserAccessibility: class TestDecodeJwtToken: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True): diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index c342e8994b..bd13527e14 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -85,7 +85,7 @@ class TestPauseStatePersistenceLayerTestContainers: return WorkflowRunService(engine) @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): + def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service): """Set up test data for each test method using TestContainers.""" # Create test tenant and account from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus @@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers: generate_entity=entity, ) - def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers): + def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session): """Test complete pause flow: event -> state serialization -> database save -> storage save.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert isinstance(persisted_entity, WorkflowAppGenerateEntity) assert persisted_entity.workflow_execution_id == self.test_workflow_run_id - def test_state_persistence_and_retrieval(self, db_session_with_containers): + def test_state_persistence_and_retrieval(self, db_session_with_containers: Session): """Test that pause state can be persisted and retrieved correctly.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert retrieved_state["node_run_steps"] == 10 assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id - def test_database_transaction_handling(self, db_session_with_containers): + def test_database_transaction_handling(self, db_session_with_containers: Session): """Test that database transactions are handled correctly.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert pause_model.resumed_at is None assert pause_model.state_object_key != "" - def test_file_storage_integration(self, db_session_with_containers): + def test_file_storage_integration(self, db_session_with_containers: Session): """Test integration with file storage system.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps() assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id - def test_workflow_with_different_creators(self, db_session_with_containers): + def test_workflow_with_different_creators(self, db_session_with_containers: Session): """Test pause state with workflows created by different users.""" # Arrange - Create workflow with different creator different_user_id = str(uuid.uuid4()) @@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers: resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id - def test_layer_ignores_non_pause_events(self, db_session_with_containers): + def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session): """Test that layer ignores non-pause events.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -562,7 +562,7 @@ class TestPauseStatePersistenceLayerTestContainers: ).all() assert len(pause_states) == 0 - def test_layer_requires_initialization(self, db_session_with_containers): + def test_layer_requires_initialization(self, db_session_with_containers: Session): """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index a60159c66a..54ee133bfe 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -15,6 +15,7 @@ from uuid import uuid4 import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client @@ -40,7 +41,7 @@ class TestTenantIsolatedTaskQueueIntegration: return Faker() @pytest.fixture - def test_tenant_and_account(self, db_session_with_containers, fake): + def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker): """Create test tenant and account for testing.""" # Create account account = Account( @@ -94,7 +95,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" assert queue._task_key == f"tenant_test-key_task:{tenant.id}" - def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake): + def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker): """Test that different tenants have isolated queues.""" tenant1, _ = test_tenant_and_account @@ -176,7 +177,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert len(remaining_tasks) == 2 assert remaining_tasks == ["task4", "task5"] - def test_push_and_pull_complex_objects(self, test_queue, fake): + def test_push_and_pull_complex_objects(self, test_queue, fake: Faker): """Test pushing and pulling complex object tasks.""" # Create complex task objects as dictionaries (not dataclass instances) tasks = [ @@ -218,7 +219,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert pulled_task["data"] == original_task["data"] assert pulled_task["metadata"] == original_task["metadata"] - def test_mixed_task_types(self, test_queue, fake): + def test_mixed_task_types(self, test_queue, fake: Faker): """Test pushing and pulling mixed string and object tasks.""" string_task = "simple_string_task" object_task = { @@ -267,7 +268,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Verify task key has expired assert test_queue.get_task_key() is None - def test_large_task_batch(self, test_queue, fake): + def test_large_task_batch(self, test_queue, fake: Faker): """Test handling large batches of tasks.""" # Create large batch of tasks large_batch = [] @@ -292,7 +293,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert isinstance(task, dict) assert task["index"] == i # FIFO order - def test_queue_operations_isolation(self, test_tenant_and_account, fake): + def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker): """Test concurrent operations on different queues.""" tenant, _ = test_tenant_and_account @@ -312,7 +313,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert tasks2 == ["task1_queue2", "task2_queue2"] assert tasks1 != tasks2 - def test_task_wrapper_serialization_roundtrip(self, test_queue, fake): + def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker): """Test TaskWrapper serialization and deserialization roundtrip.""" # Create complex nested data complex_data = { @@ -346,7 +347,7 @@ class TestTenantIsolatedTaskQueueIntegration: task = test_queue.pull_tasks(1) assert task[0] == invalid_json_task - def test_real_world_batch_processing_scenario(self, test_queue, fake): + def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker): """Test realistic batch processing scenario.""" # Simulate batch processing tasks batch_tasks = [] @@ -403,7 +404,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return Faker() @pytest.fixture - def test_tenant_and_account(self, db_session_with_containers, fake): + def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker): """Create test tenant and account for testing.""" # Create account account = Account( @@ -435,7 +436,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return tenant, account - def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake): + def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker): """ Test compatibility with legacy queues containing only string data. @@ -465,7 +466,7 @@ class TestTenantIsolatedTaskQueueCompatibility: expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] assert pulled_tasks == expected_order - def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake): + def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker): """ Test complete migration scenario from legacy to new system. @@ -546,7 +547,7 @@ class TestTenantIsolatedTaskQueueCompatibility: assert task["tenant_id"] == tenant.id assert task["processing_type"] == "new_system" - def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake): + def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker): """ Test error recovery when legacy queue contains malformed data. diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 00d7496a40..9da6b04a2c 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw class TestGetAvailableDatasetsIntegration: def test_returns_datasets_with_available_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].name == dataset.name def test_filters_out_datasets_with_only_archived_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_filters_out_datasets_with_only_disabled_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_filters_out_datasets_with_non_completed_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_includes_external_datasets_without_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that external datasets are returned even with no available documents. @@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].id == dataset.id assert result[0].provider == "external" - def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): # Arrange fake = Faker() @@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].tenant_id == tenant1.id def test_returns_empty_list_when_no_datasets_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration: # Assert assert result == [] - def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies): + def test_returns_only_requested_dataset_ids( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): # Arrange fake = Faker() @@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration: class TestKnowledgeRetrievalIntegration: def test_knowledge_retrieval_with_available_datasets( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration: assert isinstance(result, list) def test_knowledge_retrieval_no_available_datasets( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration: assert result == [] def test_knowledge_retrieval_rate_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index 177fb95ff3..e71079829f 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_service import ApiKeyAuthService @@ -31,7 +32,7 @@ class TestApiKeyAuthService: def mock_args(self, category, provider, mock_credentials) -> dict: return {"category": category, "provider": provider, "credentials": mock_credentials} - def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False): binding = DataSourceApiKeyAuthBinding( tenant_id=tenant_id, category=category, @@ -44,7 +45,7 @@ class TestApiKeyAuthService: return binding def test_get_provider_auth_list_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() @@ -56,14 +57,16 @@ class TestApiKeyAuthService: assert len(tenant_results) == 1 assert tenant_results[0].provider == provider - def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_get_provider_auth_list_empty( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): result = ApiKeyAuthService.get_provider_auth_list(tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] def test_get_provider_auth_list_filters_disabled( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True @@ -78,7 +81,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_success( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -97,7 +106,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_validation_failed( - self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = False @@ -112,7 +121,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_encrypts_api_key( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -128,7 +143,13 @@ class TestApiKeyAuthService: mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) def test_get_auth_credentials_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + self, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + category, + provider, + mock_credentials, ): self._create_binding( db_session_with_containers, @@ -144,14 +165,14 @@ class TestApiKeyAuthService: assert result == mock_credentials def test_get_auth_credentials_not_found( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) assert result is None def test_get_auth_credentials_json_parsing( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} self._create_binding( @@ -169,7 +190,7 @@ class TestApiKeyAuthService: assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" def test_delete_provider_auth_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): binding = self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider @@ -183,7 +204,9 @@ class TestApiKeyAuthService: remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() assert remaining is None - def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_delete_provider_auth_not_found( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): # Should not raise when binding not found ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index f48c6da690..e78fa27976 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -10,6 +10,7 @@ from uuid import uuid4 import httpx import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory @@ -114,7 +115,7 @@ class TestAuthIntegration: assert result2[0].tenant_id == tenant_id_2 def test_cross_tenant_access_prevention( - self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category ): result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index 42d587b7f7..327f14ddfe 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType @@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument: "user_id": user_id, } - def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_waiting_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in waiting state. @@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument: mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True") def test_pause_document_indexing_state_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful pause of document in indexing state. @@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument: assert document.is_paused is True assert document.paused_by == mock_document_service_dependencies["user_id"] - def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_parsing_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in parsing state. @@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is True - def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_completed_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause completed document. @@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is False - def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_error_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause document in error state. @@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument: "recover_task": mock_task, } - def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_paused_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful recovery of paused document. @@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument: document.dataset_id, document.id ) - def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_not_paused_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to recover non-paused document. @@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument: "user_id": user_id, } - def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_single_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of single document. @@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument: dataset.id, [document.id], mock_document_service_dependencies["user_id"] ) - def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_multiple_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of multiple documents. @@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument: ) def test_retry_document_concurrent_retry_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is already being retried. @@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument: assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when current_user is missing. @@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: } def test_batch_update_document_status_enable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch enabling of documents. @@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: assert mock_document_service_dependencies["add_task"].delay.call_count == 2 def test_batch_update_document_status_disable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch disabling of documents. @@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_archive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch archiving of documents. @@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_unarchive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch unarchiving of documents. @@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_empty_list( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test handling of empty document list. @@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_not_called() def test_batch_update_document_status_document_indexing_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is being indexed. @@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument: "current_user": mock_current_user, } - def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies): """ Test successful document renaming. @@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument: assert result == document assert document.name == new_name - def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_built_in_fields( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with built-in fields enabled. @@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument: assert document.doc_metadata["document_name"] == new_name assert document.doc_metadata["existing_key"] == "existing_value" - def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_upload_file( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with associated upload file. @@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument: assert upload_file.name == new_name def test_rename_document_dataset_not_found_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when dataset is not found. @@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Dataset not found"): DocumentService.rename_document(dataset_id, document_id, new_name) - def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_not_found_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when document is not found. @@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Document not found"): DocumentService.rename_document(dataset.id, document_id, new_name) - def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_permission_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when user lacks permission. diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py index 4e8255d8ed..e73c2afe7f 100644 --- a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from redis import RedisError +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from models.account import TenantAccountJoin @@ -122,7 +123,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_multiple_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -144,7 +145,7 @@ class TestSyncAccountDeletion: assert queued_workspace_ids == set(tenant_ids) def test_sync_account_deletion_no_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: mock_config.ENTERPRISE_ENABLED = True @@ -155,7 +156,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_partial_failure( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -180,7 +181,7 @@ class TestSyncAccountDeletion: assert mock_queue_task.call_count == 3 def test_sync_account_deletion_all_failures( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_id = str(uuid4()) diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py index 2b842629a7..724dd19f92 100644 --- a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -3,6 +3,8 @@ from __future__ import annotations from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from models.model import App, RecommendedApp, Site from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_type import RecommendAppType @@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval: class TestFetchRecommendedAppsFromDb: - def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb: assert "assistant" in result["categories"] assert "writing" in result["categories"] - def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + def test_falls_back_to_default_language_when_empty( + self, flask_app_with_containers, db_session_with_containers: Session + ): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id in app_ids - def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_site(db_session_with_containers, app_id=app1.id) @@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id not in app_ids - def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb: class TestFetchRecommendedAppDetailFromDb: - def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session): result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) assert result is None - def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb: assert result is None @patch("services.recommend_app.database.database_retrieval.AppDslService") - def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 3ec265d009..f78037e503 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -2,6 +2,7 @@ import copy import pytest from faker import Faker +from sqlalchemy.orm import Session from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService: # for consistency with other test files return {} - def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_baichuan_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for Baichuan model. @@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_common_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for common models. @@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_prompt_case_insensitive_baichuan_detection( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan model detection is case insensitive. @@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text def test_get_common_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for chat app with completion mode. @@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_chat_app_chat_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation for chat app with chat mode. @@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with completion mode. @@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with chat mode. @@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService: assert CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation without context. @@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_common_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported app mode. @@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_common_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported model mode. @@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService: # Assert: Verify empty dict is returned assert result == {} - def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_completion_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test completion prompt generation with context. @@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService: assert result_text == CONTEXT + original_text def test_get_completion_prompt_without_context( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test completion prompt generation without context. @@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService: assert result_text == original_text assert CONTEXT not in result_text - def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation with context. @@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService: assert original_text in result_text assert result_text == CONTEXT + original_text - def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_without_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation without context. @@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService: assert CONTEXT not in result_text def test_get_baichuan_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with completion mode. @@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_chat_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with chat mode. @@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with completion mode. @@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with chat mode. @@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_baichuan_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test Baichuan prompt generation without context. @@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported app mode. @@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_baichuan_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported model mode. @@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_prompt_all_app_modes_common_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with common model. @@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService: assert result != {} def test_get_prompt_all_app_modes_baichuan_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with Baichuan model. @@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService: assert result is not None assert result != {} - def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test prompt generation with edge cases. @@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService: # Should either return a valid result or empty dict, but not crash assert result is not None - def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that original templates are not modified. @@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService: assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_baichuan_template_immutability( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that original Baichuan templates are not modified. @@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService: assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies): + def test_context_integration_consistency( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test consistency of context integration across different scenarios. @@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService: assert prompt_text.startswith(CONTEXT) def test_baichuan_context_integration_consistency( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test consistency of Baichuan context integration across different scenarios. diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 1835650c42..6b844615b5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -10,6 +10,8 @@ from uuid import uuid4 import pytest import yaml from faker import Faker +from flask import Flask +from sqlalchemy.orm import Session from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -88,7 +90,7 @@ class TestAppDslService: """Integration tests for AppDslService using testcontainers.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -129,7 +131,7 @@ class TestAppDslService: "enterprise_service": mock_enterprise_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): fake = Faker() with patch("services.account_service.FeatureService") as mock_account_feature_service: mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -206,7 +208,7 @@ class TestAppDslService: # ── Import: Validation ──────────────────────────────────────────── - def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers): + def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Invalid import_mode"): service.import_app( @@ -215,7 +217,7 @@ class TestAppDslService: yaml_content="version: '0.1.0'", ) - def test_import_app_missing_yaml_content(self, db_session_with_containers): + def test_import_app_missing_yaml_content(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -225,7 +227,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "yaml_content is required" in result.error - def test_import_app_missing_yaml_url(self, db_session_with_containers): + def test_import_app_missing_yaml_url(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -235,7 +237,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "yaml_url is required" in result.error - def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers): + def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -245,7 +247,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "content must be a mapping" in result.error - def test_import_app_version_not_str_returns_failed(self, db_session_with_containers): + def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) result = service.import_app( @@ -256,7 +258,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Invalid version type" in result.error - def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers): + def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -266,7 +268,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Missing app data" in result.error - def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): def bad_safe_load(_content: str): raise yaml.YAMLError("bad") @@ -281,7 +283,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert result.error.startswith("Invalid YAML format:") - def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): monkeypatch.setattr( AppDslService, "_create_or_update_app", @@ -299,7 +301,7 @@ class TestAppDslService: # ── Import: YAML URL ────────────────────────────────────────────── - def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): monkeypatch.setattr( app_dsl_service.ssrf_proxy, "get", @@ -315,7 +317,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Error fetching YAML from URL: boom" in result.error - def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch): response = MagicMock() response.content = b"" response.raise_for_status.return_value = None @@ -330,7 +332,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Empty content" in result.error - def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch): response = MagicMock() response.content = b"x" * (DSL_MAX_SIZE + 1) response.raise_for_status.return_value = None @@ -345,7 +347,9 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "File size exceeds" in result.error - def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_user_attachments_keeps_original_url( + self, db_session_with_containers: Session, monkeypatch + ): yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml" yaml_bytes = _pending_yaml_content() @@ -371,7 +375,7 @@ class TestAppDslService: assert result.imported_dsl_version == "99.0.0" assert requested_urls == [yaml_url] - def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch): yaml_url = "https://github.com/acme/repo/blob/main/app.yml" raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml" yaml_bytes = _pending_yaml_content() @@ -400,7 +404,7 @@ class TestAppDslService: # ── Import: App ID checks ──────────────────────────────────────── - def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers): + def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -412,7 +416,7 @@ class TestAppDslService: assert result.error == "App not found" def test_import_app_overwrite_only_allows_workflow_and_advanced_chat( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) assert app.mode == "chat" @@ -429,7 +433,7 @@ class TestAppDslService: # ── Import: Flow ────────────────────────────────────────────────── - def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers): + def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -449,7 +453,7 @@ class TestAppDslService: assert stored is not None def test_import_app_completed_uses_declared_dependencies( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) @@ -483,7 +487,7 @@ class TestAppDslService: @pytest.mark.parametrize("has_workflow", [True, False]) def test_import_app_legacy_versions_extract_dependencies( - self, db_session_with_containers, monkeypatch, has_workflow: bool + self, db_session_with_containers: Session, monkeypatch, has_workflow: bool ): monkeypatch.setattr( AppDslService, @@ -540,13 +544,13 @@ class TestAppDslService: # ── Confirm Import ──────────────────────────────────────────────── - def test_confirm_import_expired_returns_failed(self, db_session_with_containers): + def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.confirm_import(import_id=str(uuid4()), account=_account_mock()) assert result.status == ImportStatus.FAILED assert "expired" in result.error - def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch): + def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" @@ -579,7 +583,7 @@ class TestAppDslService: assert result.app_id == created_app.id assert redis_client.get(redis_key) is None - def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers): + def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123") @@ -589,7 +593,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "validation error" in result.error - def test_confirm_import_exception_returns_failed(self, db_session_with_containers): + def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json") @@ -600,13 +604,13 @@ class TestAppDslService: # ── Check Dependencies ──────────────────────────────────────────── - def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers): + def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) app_model = _app_stub() result = service.check_dependencies(app_model=app_model) assert result.leaked_dependencies == [] - def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch): + def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch): app_id = str(uuid4()) pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id) redis_client.setex( @@ -634,7 +638,9 @@ class TestAppDslService: result = service.check_dependencies(app_model=_app_stub(id=app_id)) assert len(result.leaked_dependencies) == 1 - def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_dependencies_with_real_app( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' @@ -650,12 +656,12 @@ class TestAppDslService: # ── Create/Update App ───────────────────────────────────────────── - def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers): + def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="loss app mode"): service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) - def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch): + def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch): fixed_now = object() monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) @@ -707,7 +713,7 @@ class TestAppDslService: assert app.icon_background == "#222222" assert app.updated_at is fixed_now - def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers): + def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session): account = _account_mock() account.current_tenant_id = None service = AppDslService(db_session_with_containers) @@ -719,7 +725,7 @@ class TestAppDslService: ) def test_create_or_update_app_creates_workflow_app_and_saves_dependencies( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) @@ -755,7 +761,7 @@ class TestAppDslService: stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}") assert stored is not None - def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers): + def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing workflow data"): service._create_or_update_app( @@ -764,7 +770,7 @@ class TestAppDslService: account=_account_mock(), ) - def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers): + def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing model_config"): service._create_or_update_app( @@ -774,7 +780,7 @@ class TestAppDslService: ) def test_create_or_update_app_chat_creates_model_config_and_sends_event( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.app_model_config_id = None @@ -793,7 +799,7 @@ class TestAppDslService: db_session_with_containers.expire_all() assert app.app_model_config_id is not None - def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers): + def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Invalid app mode"): service._create_or_update_app( @@ -870,7 +876,7 @@ class TestAppDslService: assert data["app"]["icon_type"] == "image" assert data["app"]["icon_background"] == "#FFEAD5" - def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) model_config = AppModelConfig( @@ -908,7 +914,9 @@ class TestAppDslService: assert "model_config" in exported_data assert "dependencies" in exported_data - def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_workflow_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" db_session_with_containers.commit() @@ -941,7 +949,9 @@ class TestAppDslService: assert "workflow" in exported_data assert "dependencies" in exported_data - def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_with_workflow_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" db_session_with_containers.commit() @@ -981,7 +991,7 @@ class TestAppDslService: assert "workflow" in exported_data def test_export_dsl_with_invalid_workflow_id_raises_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py index 768a8baee2..d0c07f0de8 100644 --- a/api/tests/test_containers_integration_tests/services/test_attachment_service.py +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -7,7 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound import services.attachment_service as attachment_service_module @@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService class TestAttachmentService: - def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile: upload_file = UploadFile( tenant_id=tenant_id or str(uuid4()), storage_type=StorageType.OPENDAL, @@ -60,7 +60,7 @@ class TestAttachmentService: with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): AttachmentService(session_factory=invalid_session_factory) - def test_should_return_base64_when_file_exists(self, db_session_with_containers): + def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session): upload_file = self._create_upload_file(db_session_with_containers) service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) @@ -70,7 +70,7 @@ class TestAttachmentService: assert result == base64.b64encode(b"binary-content").decode() mock_load.assert_called_once_with(upload_file.key) - def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session): service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) with patch.object(attachment_service_module.storage, "load_once") as mock_load: diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 8092c7ad75..4893126d7f 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from extensions.ext_redis import redis_client @@ -24,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache: """ @pytest.fixture(autouse=True) - def setup_redis_cleanup(self, flask_app_with_containers): + def setup_redis_cleanup(self, flask_app_with_containers: Flask): """Clean up Redis cache before and after each test.""" with flask_app_with_containers.app_context(): # Clean up before test @@ -56,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache: return value return None - def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -87,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was not called mock_get_plan_bulk.assert_not_called() - def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are not in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -127,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1 > 0 assert ttl_1 <= 600 # Should be <= 600 seconds - def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when some tenants are in cache, some are not.""" with flask_app_with_containers.app_context(): # Arrange @@ -158,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == missing_plan["tenant-3"] - def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask): """Test fallback to API when Redis mget fails.""" with flask_app_with_containers.app_context(): # Arrange @@ -189,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert cached_1 is not None assert cached_2 is not None - def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache contains invalid JSON.""" with flask_app_with_containers.app_context(): # Arrange @@ -241,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == expected_plans["tenant-3"] - def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache data doesn't match SubscriptionPlan schema.""" with flask_app_with_containers.app_context(): # Arrange @@ -274,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was called for tenant-2 and tenant-3 mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) - def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask): """Test that pipeline failure doesn't affect return value.""" with flask_app_with_containers.app_context(): # Arrange @@ -303,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify pipeline was attempted mock_pipeline.assert_called_once() - def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask): """Test with empty tenant_ids list.""" with flask_app_with_containers.app_context(): # Act @@ -321,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache: # But we should check that mget was not called at all # Since we can't easily verify this without more mocking, we just verify the result - def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask): """Test that expired cache keys are treated as cache misses.""" with flask_app_with_containers.app_context(): # Arrange diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 98c38f2b5f..8aa10129c1 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory: class TestConversationServicePagination: """Test conversation pagination operations.""" - def test_pagination_with_non_empty_include_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session): """ Test that non-empty include_ids filters properly. @@ -204,7 +205,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[0].id, conversations[1].id} - def test_pagination_with_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that empty exclude_ids doesn't filter. @@ -237,7 +238,7 @@ class TestConversationServicePagination: # Assert assert len(result.data) == len(conversations) - def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that non-empty exclude_ids filters properly. @@ -271,7 +272,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[2].id} - def test_pagination_with_sorting_descending(self, db_session_with_containers): + def test_pagination_with_sorting_descending(self, db_session_with_containers: Session): """ Test pagination with descending sort order. @@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation: within conversations. """ - def test_pagination_by_first_id_without_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session): """ Test message pagination without specifying first_id. @@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 3 # All 3 messages returned assert result.has_more is False # No more messages available (3 < limit of 10) - def test_pagination_by_first_id_with_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session): """ Test message pagination with first_id specified. @@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 2 # Only 2 messages returned after first_id assert result.has_more is False # No more messages available (2 < limit of 10) - def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers): + def test_pagination_by_first_id_raises_error_when_first_message_not_found( + self, db_session_with_containers: Session + ): """ Test that FirstMessageNotExistsError is raised when first_id doesn't exist. @@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation: limit=10, ) - def test_pagination_with_has_more_flag(self, db_session_with_containers): + def test_pagination_with_has_more_flag(self, db_session_with_containers: Session): """ Test that has_more flag is correctly set when there are more messages. @@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set - def test_pagination_with_ascending_order(self, db_session_with_containers): + def test_pagination_with_ascending_order(self, db_session_with_containers: Session): """ Test message pagination with ascending order. @@ -512,7 +515,7 @@ class TestConversationServiceSummarization: """ @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session): """ Test successful auto-generation of conversation name. @@ -552,7 +555,7 @@ class TestConversationServiceSummarization: app_model.tenant_id, first_message.query, conversation.id, app_model.id ) - def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers): + def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session): """ Test that MessageNotExistsError is raised when conversation has no messages. @@ -571,7 +574,9 @@ class TestConversationServiceSummarization: ConversationService.auto_generate_name(app_model, conversation) @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_handles_llm_failure_gracefully( + self, mock_llm_generator, db_session_with_containers: Session + ): """ Test that LLM generation failures are suppressed and don't crash. @@ -604,7 +609,7 @@ class TestConversationServiceSummarization: assert conversation.name == original_name # Name remains unchanged @patch("services.conversation_service.naive_utc_now") - def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers): + def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session): """ Test renaming conversation with manual name. @@ -638,7 +643,7 @@ class TestConversationServiceSummarization: assert conversation.updated_at == mock_time @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers): + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session): """ Test rename delegates to auto_generate_name when auto_generate is True. @@ -682,7 +687,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_from_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating annotation from existing message. @@ -721,7 +728,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_without_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating standalone annotation without message. @@ -753,7 +762,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test updating an existing annotation. @@ -800,7 +809,7 @@ class TestConversationServiceMessageAnnotation: mock_add_task.delay.assert_not_called() @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving paginated annotation list. @@ -836,7 +845,7 @@ class TestConversationServiceMessageAnnotation: assert result_total == 5 @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving annotations with keyword filtering. @@ -885,7 +894,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test direct annotation insertion without message reference. @@ -919,7 +928,7 @@ class TestConversationServiceExport: Tests retrieving conversation data for export purposes. """ - def test_get_conversation_success(self, db_session_with_containers): + def test_get_conversation_success(self, db_session_with_containers: Session): """Test successful retrieval of conversation.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -937,7 +946,7 @@ class TestConversationServiceExport: # Assert assert result == conversation - def test_get_conversation_not_found(self, db_session_with_containers): + def test_get_conversation_not_found(self, db_session_with_containers: Session): """Test ConversationNotExistsError when conversation doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -949,7 +958,7 @@ class TestConversationServiceExport: ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user) @patch("services.annotation_service.current_account_with_tenant") - def test_export_annotation_list(self, mock_current_account, db_session_with_containers): + def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session): """Test exporting all annotations for an app.""" # Arrange app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -977,7 +986,7 @@ class TestConversationServiceExport: # Assert assert len(result) == 10 - def test_get_message_success(self, db_session_with_containers): + def test_get_message_success(self, db_session_with_containers: Session): """Test successful retrieval of a message.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -1001,7 +1010,7 @@ class TestConversationServiceExport: # Assert assert result == message - def test_get_message_not_found(self, db_session_with_containers): + def test_get_message_not_found(self, db_session_with_containers: Session): """Test MessageNotExistsError when message doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -1012,7 +1021,7 @@ class TestConversationServiceExport: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4())) - def test_get_conversation_for_end_user(self, db_session_with_containers): + def test_get_conversation_for_end_user(self, db_session_with_containers: Session): """ Test retrieving conversation created by end user via API. @@ -1038,7 +1047,7 @@ class TestConversationServiceExport: assert result == conversation @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session): """ Test conversation deletion with async cleanup. @@ -1071,7 +1080,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_called_once_with(conversation_id) @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session): """ Test deletion is denied when conversation belongs to a different account. """ @@ -1102,7 +1111,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_not_called() @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers): + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session): """ Test that delete propagates exceptions and does not trigger the cleanup task. diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py index 0b7bd9ca64..6c292dbc4b 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -5,7 +5,8 @@ from unittest.mock import patch from uuid import uuid4 import pytest -from sqlalchemy.orm import sessionmaker +from flask import Flask +from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db @@ -149,7 +150,7 @@ class ConversationServiceVariableIntegrationFactory: @pytest.fixture -def real_conversation_service_session_factory(flask_app_with_containers): +def real_conversation_service_session_factory(flask_app_with_containers: Flask): del flask_app_with_containers real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -162,7 +163,7 @@ def real_conversation_service_session_factory(flask_app_with_containers): class TestConversationServiceVariables: def test_get_conversational_variable_success( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -200,7 +201,7 @@ class TestConversationServiceVariables: assert result.has_more is False def test_get_conversational_variable_with_last_id( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -242,7 +243,7 @@ class TestConversationServiceVariables: assert result.has_more is False def test_get_conversational_variable_last_id_not_found_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -259,7 +260,7 @@ class TestConversationServiceVariables: ) def test_get_conversational_variable_sets_has_more( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -287,7 +288,7 @@ class TestConversationServiceVariables: assert result.has_more is True def test_update_conversation_variable_success( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -320,7 +321,7 @@ class TestConversationServiceVariables: assert result["updated_at"] == updated_at def test_update_conversation_variable_not_found_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -337,7 +338,7 @@ class TestConversationServiceVariables: ) def test_update_conversation_variable_type_mismatch_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -360,7 +361,7 @@ class TestConversationServiceVariables: ) def test_update_conversation_variable_integer_number_compatibility( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -390,7 +391,7 @@ class TestConversationServiceVariables: class TestConversationServicePaginationWithContainers: - def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers): + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) @@ -404,7 +405,7 @@ class TestConversationServicePaginationWithContainers: invoke_from=InvokeFrom.WEB_APP, ) - def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers): + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) base_time = datetime(2024, 1, 1, 8, 0, 0) @@ -442,7 +443,7 @@ class TestConversationServicePaginationWithContainers: assert newest.id != middle.id assert [conversation.id for conversation in result.data] == [oldest.id] - def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers): + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") @@ -462,7 +463,7 @@ class TestConversationServicePaginationWithContainers: assert alpha.id != beta.id assert [conversation.id for conversation in result.data] == [gamma.id] - def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers): + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) end_user = factory.create_end_user(db_session_with_containers, app) @@ -493,7 +494,7 @@ class TestConversationServicePaginationWithContainers: assert account_conversation.id != end_user_conversation.id assert [conversation.id for conversation in result.data] == [end_user_conversation.id] - def test_pagination_filters_to_account_console_source(self, db_session_with_containers): + def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) end_user = factory.create_end_user(db_session_with_containers, app) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 02ab3f8314..638a962f18 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pytest -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from extensions.ext_database import db from graphon.variables import StringVariable @@ -13,7 +13,12 @@ from services.conversation_variable_updater import ConversationVariableNotFoundE class TestConversationVariableUpdater: def _create_conversation_variable( - self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + self, + db_session_with_containers: Session, + *, + conversation_id: str, + variable: StringVariable, + app_id: str | None = None, ) -> ConversationVariable: row = ConversationVariable( id=variable.id, @@ -25,7 +30,7 @@ class TestConversationVariableUpdater: db_session_with_containers.commit() return row - def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="old value") self._create_conversation_variable( @@ -42,7 +47,7 @@ class TestConversationVariableUpdater: assert row is not None assert row.data == updated_variable.model_dump_json() - def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="value") updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) @@ -50,7 +55,7 @@ class TestConversationVariableUpdater: with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): updater.update(conversation_id=conversation_id, variable=variable) - def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session): updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) result = updater.flush() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 0f63d98642..09ba041244 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -3,6 +3,7 @@ from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.errors.error import QuotaExceededError from models import TenantCreditPool @@ -14,7 +15,7 @@ class TestCreditPoolService: def _create_tenant_id(self) -> str: return str(uuid4()) - def test_create_default_pool(self, db_session_with_containers): + def test_create_default_pool(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) @@ -25,7 +26,7 @@ class TestCreditPoolService: assert pool.quota_used == 0 assert pool.quota_limit > 0 - def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -35,17 +36,17 @@ class TestCreditPoolService: assert result.tenant_id == tenant_id assert result.pool_type == ProviderQuotaType.TRIAL - def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session): result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None - def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session): result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) assert result is False - def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -53,7 +54,7 @@ class TestCreditPoolService: assert result is True - def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) # Exhaust credits @@ -64,11 +65,11 @@ class TestCreditPoolService: assert result is False - def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session): with pytest.raises(QuotaExceededError, match="Credit pool not found"): CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) - def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) pool.quota_used = pool.quota_limit @@ -77,7 +78,7 @@ class TestCreditPoolService: with pytest.raises(QuotaExceededError, match="No credits remaining"): CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) - def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) credits_required = 10 @@ -89,7 +90,7 @@ class TestCreditPoolService: pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert pool.quota_used == credits_required - def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) remaining = 5 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 71c8874f79..f9898e2cfa 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db @@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory: class TestDatasetPermissionServiceGetPartialMemberList: """Verify partial-member list reads against persisted DatasetPermission rows.""" - def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session): """ Test retrieving partial member list with multiple members. """ @@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 3 - def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session): """ Test retrieving partial member list with single member. """ @@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 1 - def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session): """ Test retrieving partial member list when no members exist. """ @@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: class TestDatasetPermissionServiceUpdatePartialMemberList: """Verify partial-member list updates against persisted DatasetPermission rows.""" - def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session): """ Test adding new partial members to a dataset. """ @@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {member_1.id, member_2.id} - def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session): """ Test replacing existing partial members with new ones. """ @@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {new_member_1.id, new_member_2.id} - def test_update_partial_member_list_empty_list(self, db_session_with_containers): + def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test updating with empty member list (clearing all members). """ @@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: class TestDatasetPermissionServiceClearPartialMemberList: """Verify partial-member clearing against persisted DatasetPermission rows.""" - def test_clear_partial_member_list_success(self, db_session_with_containers): + def test_clear_partial_member_list_success(self, db_session_with_containers: Session): """ Test successful clearing of partial member list. """ @@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test clearing partial member list when no members exist. """ @@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" - def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session): """Test that users from different tenants cannot access dataset.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other_user) - def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session): """Test that tenant owners can access any dataset regardless of permission level.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, owner) - def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session): """Test ONLY_ME permission allows only the dataset creator to access.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, creator) - def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session): """Test ONLY_ME permission denies access to non-creators.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other) - def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session): """Test ALL_TEAM permission allows any team member to access the dataset.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, member) - def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_with_permission_success( + self, db_session_with_containers: Session + ): """ Test that user with explicit permission can access partial_members dataset. """ @@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission: permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert user.id in permissions - def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_without_permission_error( + self, db_session_with_containers: Session + ): """ Test error when user without permission tries to access partial_members dataset. """ @@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session): """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0de3c64c4f..e6ee896a52 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration: class TestDocumentServicePauseRecoverRetry: """Tests for pause/recover/retry orchestration using real DB and Redis.""" - def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"): factory = DatasetServiceIntegrationDataFactory account, tenant = factory.create_account_with_tenant(db_session_with_containers) dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) @@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry: db_session_with_containers.commit() return doc, account - def test_pause_document_success(self, db_session_with_containers): + def test_pause_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is not None redis_client.delete(cache_key) - def test_pause_document_invalid_status_error(self, db_session_with_containers): + def test_pause_document_invalid_status_error(self, db_session_with_containers: Session): from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry: with pytest.raises(DocumentIndexingError): DocumentService.pause_document(doc) - def test_recover_document_success(self, db_session_with_containers): + def test_recover_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is None recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) - def test_retry_document_indexing_success(self, db_session_with_containers): + def test_retry_document_indexing_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py index c486ff5613..08de79f4b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -6,6 +6,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin from services.dataset_service import DatasetService @@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset: permission="only_me", ) - def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session): tenant, _ = self._create_tenant_and_account(db_session_with_containers) mock_user = Mock(id=None) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 3cac964d89..c43a5d5978 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,8 @@ from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory: class TestDatasetServiceDeleteDataset: """Integration coverage for DatasetService.delete_dataset using testcontainers.""" - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset: dataset.pipeline_id, ) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """Return False without scheduling cleanup when the target dataset does not exist.""" # Arrange owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py index 1b4179c9c7..0603a1e27f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -6,6 +6,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -363,7 +364,7 @@ class TestDatasetServicePermissionsAndLifecycle: DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) - def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask): with flask_app_with_containers.app_context(): with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(str(uuid4()), True) @@ -473,7 +474,7 @@ class TestDatasetCollectionBindingServiceIntegration: assert persisted.type == "dataset" assert persisted.collection_name - def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask): with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index fe426ae516..69c39b8bfb 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -6,6 +6,7 @@ from datetime import UTC, datetime, timedelta from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.commit() return run - def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None: + def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None: archive_log = WorkflowArchiveLog( tenant_id=run.tenant_id, app_id=run.app_id, @@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.add(archive_log) db_session_with_containers.commit() - def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers): + def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session): deleter = ArchivedWorkflowRunDeletion() missing_run_id = str(uuid4()) @@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == f"Workflow run {missing_run_id} not found" - def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers): + def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session): tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, @@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == f"Workflow run {run.id} is not archived" - def test_delete_batch_uses_repo(self, db_session_with_containers): + def test_delete_batch_uses_repo(self, db_session_with_containers: Session): tenant_id = str(uuid4()) base_time = datetime.now(UTC) run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time) @@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion: ).all() assert remaining_runs == [] - def test_delete_run_calls_repo(self, db_session_with_containers): + def test_delete_run_calls_repo(self, db_session_with_containers: Session): tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, @@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion: deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None - def test_delete_run_dry_run(self, db_session_with_containers): + def test_delete_run_dry_run(self, db_session_with_containers: Session): """Dry run should return success without actually deleting.""" tenant_id = str(uuid4()) run = self._create_workflow_run( @@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(WorkflowRun, run_id) is not None - def test_delete_run_exception_returns_error(self, db_session_with_containers): + def test_delete_run_exception_returns_error(self, db_session_with_containers: Session): """Exception during deletion should return failure result.""" from unittest.mock import MagicMock, patch @@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == "Database error" - def test_delete_by_run_id_success(self, db_session_with_containers): + def test_delete_by_run_id_success(self, db_session_with_containers: Session): """Successfully delete an archived workflow run by ID.""" tenant_id = str(uuid4()) base_time = datetime.now(UTC) @@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() assert db_session_with_containers.get(WorkflowRun, run_id) is None - def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session): """_get_workflow_run_repo should return a cached repo on subsequent calls.""" deleter = ArchivedWorkflowRunDeletion() diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index cafabc939b..074d448aab 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory): + def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory): """Test getting or creating end user with custom user_id.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser: assert result.type == InvokeFrom.SERVICE_API assert result.is_anonymous is False - def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory): + def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory): """Test getting or creating end user without user_id uses default session.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser: # Verify _is_anonymous is set correctly (property always returns False) assert result._is_anonymous is True - def test_get_existing_end_user(self, db_session_with_containers, factory): + def test_get_existing_end_user(self, db_session_with_containers: Session, factory): """Test retrieving an existing end user.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_create_end_user_service_api_type(self, db_session_with_containers, factory): + def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory): """Test creating new end user with SERVICE_API type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.app_id == app_id assert result.session_id == user_id - def test_create_end_user_web_app_type(self, db_session_with_containers, factory): + def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory): """Test creating new end user with WEB_APP type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.type == InvokeFrom.WEB_APP @patch("services.end_user_service.logger") - def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory): + def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory): """Test upgrading legacy end user with different type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert "Upgrading legacy EndUser" in log_call @patch("services.end_user_service.logger") - def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory): + def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory): """Test retrieving existing end user with matching type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.type == InvokeFrom.SERVICE_API mock_logger.info.assert_not_called() - def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory): + def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory): """Test creating anonymous user when user_id is None.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result._is_anonymous is True assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory): + def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory): """Test that query ordering prioritizes records with matching type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.id == matching.id assert result.id != non_matching.id - def test_external_user_id_matches_session_id(self, db_session_with_containers, factory): + def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory): """Test that external_user_id is set to match session_id.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType: InvokeFrom.DEBUGGER, ], ) - def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory): + def test_create_end_user_with_different_invoke_types( + self, db_session_with_containers: Session, invoke_type, factory + ): """Test creating end users with different InvokeFrom types.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory): + def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory): app = factory.create_app_and_account(db_session_with_containers) existing_user = factory.create_end_user( db_session_with_containers, @@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById: assert result is not None assert result.id == existing_user.id - def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory): + def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory): app = factory.create_app_and_account(db_session_with_containers) result = EndUserService.get_end_user_by_id( @@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch: def factory(self): return TestEndUserServiceFactory() - def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3): """Create multiple apps under the same tenant.""" first_app = factory.create_app_and_account(db_session_with_containers) tenant_id = first_app.tenant_id @@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch: all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() return tenant_id, all_apps - def test_create_batch_empty_app_ids(self, db_session_with_containers): + def test_create_batch_empty_app_ids(self, db_session_with_containers: Session): result = EndUserService.create_end_user_batch( type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" ) assert result == {} - def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) app_ids = [a.id for a in apps] user_id = f"user-{uuid4()}" @@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch: assert result[app_id].session_id == user_id assert result[app_id].type == InvokeFrom.SERVICE_API - def test_create_batch_default_session_id(self, db_session_with_containers, factory): + def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [a.id for a in apps] @@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch: assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID assert end_user._is_anonymous is True - def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] user_id = f"user-{uuid4()}" @@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch: assert len(result) == 2 - def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [a.id for a in apps] user_id = f"user-{uuid4()}" @@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch: for app_id in app_ids: assert first_result[app_id].id == second_result[app_id].id - def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) user_id = f"user-{uuid4()}" @@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch: "invoke_type", [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], ) - def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) user_id = f"user-{uuid4()}" diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 315936d721..f78aeaf984 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from services.feature_service import ( @@ -81,7 +82,7 @@ class TestFeatureService: fake = Faker() return fake.uuid4() - def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful feature retrieval with billing and enterprise enabled. @@ -156,7 +157,7 @@ class TestFeatureService: tenant_id ) - def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test feature retrieval for sandbox plan with specific limitations. @@ -222,7 +223,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) - def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_knowledge_rate_limit_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful knowledge rate limit retrieval with billing enabled. @@ -255,7 +258,7 @@ class TestFeatureService: tenant_id ) - def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system features retrieval with enterprise and marketplace enabled. @@ -332,7 +335,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_unauthenticated( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval for an unauthenticated user. @@ -386,7 +391,9 @@ class TestFeatureService: # Marketplace should be visible assert result.enable_marketplace is True - def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_basic_config( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with basic configuration (no enterprise). @@ -436,7 +443,9 @@ class TestFeatureService: # Verify plugin package size (uses default value from dify_config) assert result.max_plugin_package_size == 15728640 - def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_billing_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval when billing is disabled. @@ -492,7 +501,7 @@ class TestFeatureService: assert result.webapp_copyright_enabled is False def test_get_knowledge_rate_limit_billing_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test knowledge rate limit retrieval when billing is disabled. @@ -523,7 +532,9 @@ class TestFeatureService: # Verify no billing service calls mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called() - def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_enterprise_only( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with enterprise enabled but billing disabled. @@ -583,7 +594,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_not_called() def test_get_system_features_enterprise_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval when enterprise is disabled. @@ -640,7 +651,7 @@ class TestFeatureService: # Verify no enterprise service calls mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called() - def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test feature retrieval without tenant ID (billing disabled). @@ -686,7 +697,9 @@ class TestFeatureService: # Verify no billing service calls mock_external_service_dependencies["billing_service"].get_info.assert_not_called() - def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_partial_billing_info( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with partial billing information. @@ -746,7 +759,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) - def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_vector_space( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case vector space configuration. @@ -807,7 +822,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_webapp_auth( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case webapp auth configuration. @@ -863,7 +878,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_members_quota( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case members quota configuration. @@ -924,7 +941,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_plugin_installation_permission_scopes( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with different plugin installation permission scopes. @@ -1023,7 +1040,7 @@ class TestFeatureService: assert result.plugin_installation_permission.restrict_to_marketplace_only is True def test_get_features_workspace_members_missing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval when workspace members info is missing from enterprise. @@ -1064,7 +1081,9 @@ class TestFeatureService: tenant_id ) - def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_license_inactive( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with inactive license. @@ -1117,7 +1136,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_system_features_partial_enterprise_info( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with partial enterprise information. @@ -1186,7 +1205,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_limits( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case limit values. @@ -1244,7 +1265,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_protocols( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case protocol values. @@ -1297,7 +1318,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_education( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case education configuration. @@ -1353,7 +1376,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_license_limitation_model_is_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test LicenseLimitationModel.is_available method with various scenarios. @@ -1394,7 +1417,7 @@ class TestFeatureService: assert exact_limit.is_available(3) is True def test_get_features_workspace_members_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval when workspace members are disabled in enterprise. @@ -1433,7 +1456,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id) - def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_license_expired( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with expired license. @@ -1486,7 +1511,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_docs_processing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case document processing configuration. @@ -1544,7 +1569,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_branding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case branding configuration. @@ -1606,7 +1631,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_annotation_quota( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case annotation quota configuration. @@ -1668,7 +1693,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_features_edge_case_documents_upload( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case documents upload settings. @@ -1733,7 +1758,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_license_lost( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features with lost license status. @@ -1784,7 +1809,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_education_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with education feature disabled. diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index ed75363f3b..ce63e7a71a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -6,6 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session from configs import dify_config from core.workflow.human_input_adapter import ( @@ -88,7 +89,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, flask_app_with_containers, db_session_with_containers): + def test_default(self, flask_app_with_containers, db_session_with_containers: Session): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account = Account(name="Test User", email="member@example.com") db_session_with_containers.add(account) @@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index b55a19eaa9..fffa82bf5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from models.dataset import Dataset, DatasetMetadataBinding, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -65,7 +66,7 @@ class TestMetadataPartialUpdate: yield account def test_partial_update_merges_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -92,7 +93,7 @@ class TestMetadataPartialUpdate: assert updated_doc.doc_metadata["new_key"] == "new_value" def test_full_update_replaces_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -119,7 +120,7 @@ class TestMetadataPartialUpdate: assert "existing_key" not in updated_doc.doc_metadata def test_partial_update_skips_existing_binding( - self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -159,7 +160,7 @@ class TestMetadataPartialUpdate: assert len(bindings) == 1 def test_rollback_called_on_commit_failure( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py index c146a5924b..5fa5de6d80 100644 --- a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from models.model import OAuthProviderApp @@ -25,7 +26,7 @@ from services.oauth_server import ( class TestOAuthServerServiceGetProviderApp: """DB-backed tests for get_oauth_provider_app.""" - def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp: app = OAuthProviderApp( app_icon="icon.png", client_id=client_id, @@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp: db_session_with_containers.commit() return app - def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session): client_id = f"client-{uuid4()}" created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) @@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp: assert result.client_id == client_id assert result.id == created.id - def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session): result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index 7036524918..2f20949611 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -8,6 +8,7 @@ from datetime import datetime from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore @@ -39,7 +40,7 @@ class TestWorkflowRunRestore: assert result["created_at"].month == 1 assert result["name"] == "test" - def test_restore_table_records_returns_rowcount(self, db_session_with_containers): + def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() record_id = str(uuid4()) @@ -65,7 +66,7 @@ class TestWorkflowRunRestore: restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id)) assert restored_pause is not None - def test_restore_table_records_unknown_table(self, db_session_with_containers): + def test_restore_table_records_unknown_table(self, db_session_with_containers: Session): """Unknown table names should be ignored gracefully.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 970da98c55..6d5c7380b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from flask import Flask +from sqlalchemy.orm import Session from werkzeug.datastructures import FileStorage from models.enums import AppTriggerStatus, AppTriggerType @@ -52,7 +53,7 @@ class TestWebhookService: } @pytest.fixture - def test_data(self, db_session_with_containers, mock_external_dependencies): + def test_data(self, db_session_with_containers: Session, mock_external_dependencies): """Create test data for webhook service tests.""" fake = Faker() @@ -160,7 +161,7 @@ class TestWebhookService: "app_trigger": app_trigger, } - def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers: Flask): """Test successful retrieval of webhook trigger and workflow.""" webhook_id = test_data["webhook_id"] @@ -175,7 +176,7 @@ class TestWebhookService: assert node_config["id"] == "webhook_node" assert node_config["data"].title == "Test Webhook" - def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers: Flask): """Test webhook trigger not found scenario.""" with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Webhook not found"): @@ -421,7 +422,9 @@ class TestWebhookService: assert result["files"] == {} - def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers): + def test_trigger_workflow_execution_success( + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask + ): """Test successful workflow execution trigger.""" webhook_data = { "method": "POST", @@ -452,7 +455,7 @@ class TestWebhookService: mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once() def test_trigger_workflow_execution_end_user_service_failure( - self, test_data, mock_external_dependencies, flask_app_with_containers + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask ): """Test workflow execution trigger when EndUserService fails.""" webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py index 85ce3a6ba6..69cde847f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy import select from sqlalchemy.orm import Session @@ -165,7 +166,7 @@ class WebhookServiceRelationshipFactory: class TestWebhookServiceLookupWithContainers: def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -182,7 +183,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -202,7 +203,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -222,7 +223,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -239,7 +240,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -275,7 +276,7 @@ class TestWebhookServiceLookupWithContainers: class TestWebhookServiceTriggerExecutionWithContainers: def test_trigger_workflow_execution_triggers_async_workflow_successfully( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -318,7 +319,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: assert trigger_args[2].root_node_id == webhook_trigger.node_id def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -354,7 +355,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: mock_mark_rate_limited.assert_called_once_with(tenant.id) def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -386,7 +387,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: class TestWebhookServiceRelationshipSyncWithContainers: def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -401,7 +402,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: WebhookService.sync_webhook_relationships(app, workflow) def test_sync_webhook_relationships_raises_when_lock_not_acquired( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -418,7 +419,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: WebhookService.sync_webhook_relationships(app, workflow) def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -455,7 +456,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None def test_sync_webhook_relationships_sets_redis_cache_for_new_record( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -481,7 +482,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert cached_payload["webhook_id"] == "cache-webhook-id-00001" def test_sync_webhook_relationships_logs_when_lock_release_fails( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 1e57b5603d..a2cdddad61 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1530,7 +1530,7 @@ class TestWorkflowAppService: assert result_cross_tenant["total"] == 0 def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() @@ -1543,7 +1543,7 @@ class TestWorkflowAppService: ) def test_get_paginate_workflow_app_logs_filters_by_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() @@ -1558,7 +1558,9 @@ class TestWorkflowAppService: assert result["total"] >= 0 assert isinstance(result["data"], list) - def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_workflow_archive_logs( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 86cf2327c7..82fe391b08 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -45,7 +45,9 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, fake: Faker | None = None + ): """ Helper method to create a test app with realistic data for testing. @@ -80,7 +82,7 @@ class TestWorkflowDraftVariableService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index b5ce8a53de..9ba1fda08b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -12,7 +12,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models import Account, App, Workflow +from models import Account, AccountStatus, App, TenantStatus, Workflow from models.model import AppMode from models.workflow import WorkflowType from services.workflow_service import WorkflowService @@ -33,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -49,7 +49,7 @@ class TestWorkflowService: email=fake.email(), name=fake.name(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", # Set interface language for Site creation ) account.created_at = fake.date_time_this_year() @@ -62,7 +62,7 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="normal", + status=TenantStatus.NORMAL, ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() @@ -77,7 +77,7 @@ class TestWorkflowService: return account - def _create_test_app(self, db_session_with_containers: Session, fake=None): + def _create_test_app(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test app with realistic data. @@ -109,7 +109,7 @@ class TestWorkflowService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py index 29e1e240b4..afc4908c15 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -100,7 +100,7 @@ class TestWorkflowDeletion: session.flush() return provider - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -118,7 +118,7 @@ class TestWorkflowDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(Workflow, workflow_id) is None - def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + def test_delete_draft_workflow_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -130,7 +130,7 @@ class TestWorkflowDeletion: with pytest.raises(DraftWorkflowDeletionError): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -144,7 +144,7 @@ class TestWorkflowDeletion: with pytest.raises(WorkflowInUseError, match="currently in use by app"): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 4dab895135..32b76c3469 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -64,7 +64,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: db_session_with_containers.commit() return execution - def test_get_node_last_execution_found(self, db_session_with_containers): + def test_get_node_last_execution_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it exists.""" # Arrange tenant_id = str(uuid4()) @@ -110,7 +110,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result.id == expected.id assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - def test_get_node_last_execution_not_found(self, db_session_with_containers): + def test_get_node_last_execution_not_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it doesn't exist.""" # Arrange tenant_id = str(uuid4()) @@ -129,7 +129,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_get_executions_by_workflow_run_empty(self, db_session_with_containers): + def test_get_executions_by_workflow_run_empty(self, db_session_with_containers: Session): """Test getting executions for a workflow run when none exist.""" # Arrange tenant_id = str(uuid4()) @@ -147,7 +147,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == [] - def test_get_execution_by_id_found(self, db_session_with_containers): + def test_get_execution_by_id_found(self, db_session_with_containers: Session): """Test getting execution by ID when it exists.""" # Arrange execution = self._create_execution( @@ -170,7 +170,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is not None assert result.id == execution.id - def test_get_execution_by_id_not_found(self, db_session_with_containers): + def test_get_execution_by_id_not_found(self, db_session_with_containers: Session): """Test getting execution by ID when it doesn't exist.""" # Arrange repository = self._create_repository(db_session_with_containers) @@ -182,7 +182,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_delete_expired_executions(self, db_session_with_containers): + def test_delete_expired_executions(self, db_session_with_containers: Session): """Test deleting expired executions.""" # Arrange tenant_id = str(uuid4()) @@ -248,7 +248,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_2_id not in remaining_ids assert kept_execution_id in remaining_ids - def test_delete_executions_by_app(self, db_session_with_containers): + def test_delete_executions_by_app(self, db_session_with_containers: Session): """Test deleting executions by app.""" # Arrange tenant_id = str(uuid4()) @@ -313,7 +313,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert deleted_2_id not in remaining_ids assert kept_id in remaining_ids - def test_get_expired_executions_batch(self, db_session_with_containers): + def test_get_expired_executions_batch(self, db_session_with_containers: Session): """Test getting expired executions batch for backup.""" # Arrange tenant_id = str(uuid4()) @@ -370,7 +370,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_1.id in result_ids assert old_execution_2.id in result_ids - def test_delete_executions_by_ids(self, db_session_with_containers): + def test_delete_executions_by_ids(self, db_session_with_containers: Session): """Test deleting executions by IDs.""" # Arrange tenant_id = str(uuid4()) @@ -424,7 +424,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: ).all() assert remaining == [] - def test_delete_executions_by_ids_empty_list(self, db_session_with_containers): + def test_delete_executions_by_ids_empty_list(self, db_session_with_containers: Session): """Test deleting executions with empty ID list.""" # Arrange repository = self._create_repository(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 7e5c374b5d..1c8d5969e0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -71,7 +71,7 @@ class TestCleanNotionDocumentTask: yield mock_factory def test_clean_notion_document_task_success( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test successful cleanup of Notion documents with proper database operations. @@ -176,7 +176,7 @@ class TestCleanNotionDocumentTask: # 5. The task completes without errors def test_clean_notion_document_task_dataset_not_found( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior when dataset is not found. @@ -196,7 +196,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.return_value.init_index_processor.assert_not_called() def test_clean_notion_document_task_empty_document_list( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior with empty document list. @@ -240,7 +240,7 @@ class TestCleanNotionDocumentTask: assert args[1] == [] def test_clean_notion_document_task_with_different_index_types( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with different dataset index types. @@ -328,7 +328,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.reset_mock() def test_clean_notion_document_task_with_segments_no_index_node_ids( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments that have no index_node_ids. @@ -411,7 +411,7 @@ class TestCleanNotionDocumentTask: # are properly deleted from the database. def test_clean_notion_document_task_partial_document_cleanup( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with partial document cleanup scenario. @@ -513,7 +513,7 @@ class TestCleanNotionDocumentTask: # The database operations work correctly, isolating only the specified documents. def test_clean_notion_document_task_with_mixed_segment_statuses( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments in different statuses. @@ -603,7 +603,7 @@ class TestCleanNotionDocumentTask: # IndexProcessor verification would require more sophisticated mocking. def test_clean_notion_document_task_continues_when_index_processor_fails( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Index processor failure (e.g. transient billing API error propagated via @@ -707,7 +707,7 @@ class TestCleanNotionDocumentTask: assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 def test_clean_notion_document_task_with_large_number_of_documents( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with a large number of documents and segments. @@ -806,7 +806,7 @@ class TestCleanNotionDocumentTask: # The database efficiently handles large-scale deletions. def test_clean_notion_document_task_with_documents_from_different_tenants( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents from different tenants. @@ -918,7 +918,7 @@ class TestCleanNotionDocumentTask: # Only documents from the target dataset are affected, maintaining tenant separation. def test_clean_notion_document_task_with_documents_in_different_states( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents in different indexing states. @@ -1024,7 +1024,7 @@ class TestCleanNotionDocumentTask: # All documents are deleted regardless of their indexing status. def test_clean_notion_document_task_with_documents_having_metadata( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents that have rich metadata. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 9084667c31..80289c448a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker from sqlalchemy import delete +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client @@ -25,7 +26,7 @@ class TestCreateSegmentToIndexTask: """Integration tests for create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database and Redis before each test to ensure isolation.""" # Clear all test data using fixture session @@ -55,7 +56,7 @@ class TestCreateSegmentToIndexTask: "index_processor": mock_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -102,7 +103,7 @@ class TestCreateSegmentToIndexTask: return account, tenant - def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + def _create_test_dataset_and_document(self, db_session_with_containers: Session, tenant_id, account_id): """ Helper method to create a test dataset and document for testing. @@ -151,7 +152,13 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING + self, + db_session_with_containers: Session, + dataset_id, + document_id, + tenant_id, + account_id, + status=SegmentStatus.WAITING, ): """ Helper method to create a test document segment for testing. @@ -189,7 +196,9 @@ class TestCreateSegmentToIndexTask: return segment - def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful creation of segment to index. @@ -225,7 +234,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_segment_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent segment ID. @@ -246,7 +255,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_invalid_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with invalid status. @@ -277,7 +286,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated dataset. @@ -330,7 +341,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated document. @@ -367,7 +380,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with disabled document. @@ -403,7 +416,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_archived( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with archived document. @@ -439,7 +452,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_indexing_incomplete( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with document that has incomplete indexing. @@ -475,7 +488,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_processor_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of index processor exceptions. @@ -511,7 +524,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_with_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with custom keywords. @@ -543,7 +556,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with different document forms. @@ -586,7 +599,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) def test_create_segment_to_index_performance_timing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing performance and timing. @@ -617,7 +630,7 @@ class TestCreateSegmentToIndexTask: assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test concurrent execution of segment indexing tasks. @@ -654,7 +667,7 @@ class TestCreateSegmentToIndexTask: assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 def test_create_segment_to_index_large_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with large content. @@ -703,7 +716,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_redis_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing when Redis operations fail. @@ -743,7 +756,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 1 def test_create_segment_to_index_database_transaction_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with database transaction handling. @@ -775,7 +788,7 @@ class TestCreateSegmentToIndexTask: assert segment.error is not None def test_create_segment_to_index_metadata_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with metadata validation. @@ -817,7 +830,7 @@ class TestCreateSegmentToIndexTask: assert doc is not None def test_create_segment_to_index_status_transition_flow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test complete status transition flow during indexing. @@ -852,7 +865,7 @@ class TestCreateSegmentToIndexTask: assert segment.indexing_at <= segment.completed_at def test_create_segment_to_index_with_empty_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with empty or minimal content. @@ -894,7 +907,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with special characters and unicode content. @@ -940,7 +953,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_long_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with long keyword lists. @@ -974,7 +987,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with proper tenant isolation. @@ -1017,7 +1030,7 @@ class TestCreateSegmentToIndexTask: assert segment1.tenant_id != segment2.tenant_id def test_create_segment_to_index_with_none_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with None keywords parameter. @@ -1048,7 +1061,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_comprehensive_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Comprehensive integration test covering multiple scenarios. diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 684097851b..a5a3cd10b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -174,11 +175,11 @@ class TestDatasetIndexingTaskIntegration: return dataset, documents - def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: + def _query_document(self, db_session_with_containers: Session, document_id: str) -> Document | None: """Return the latest persisted document state.""" return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) - def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: + def _assert_documents_parsing(self, db_session_with_containers: Session, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" db_session_with_containers.expire_all() for document_id in document_ids: @@ -212,7 +213,9 @@ class TestDatasetIndexingTaskIntegration: assert len(opened) >= 2 assert opened_ids <= closed_ids - def test_legacy_document_indexing_task_still_works(self, db_session_with_containers, patched_external_dependencies): + def test_legacy_document_indexing_task_still_works( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Ensure the legacy task entrypoint still updates parsing status.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -225,7 +228,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_multiple_documents(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_multiple_documents( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Process multiple documents in one batch.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -240,7 +245,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == len(document_ids) self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_with_limit_check(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_with_limit_check( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Reject batches larger than configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. @@ -263,7 +270,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_error_contains(db_session_with_containers, document_ids, "batch upload limit") def test_batch_processing_sandbox_plan_single_document_only( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Reject multi-document upload under sandbox plan.""" # Arrange @@ -280,7 +287,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_not_called() self._assert_documents_error_contains(db_session_with_containers, document_ids, "does not support batch upload") - def test_batch_processing_empty_document_list(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_empty_document_list( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Handle empty list input without failing.""" # Arrange dataset, _ = self._create_test_dataset_and_documents(db_session_with_containers, document_count=0) @@ -292,7 +301,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_tenant_queue_dispatches_next_task_after_completion( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Dispatch the next queued task after current tenant task completes. @@ -337,7 +346,7 @@ class TestDatasetIndexingTaskIntegration: delete_key_spy.assert_not_called() def test_tenant_queue_deletes_running_key_when_no_follow_up_tasks( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Delete tenant running flag when queue has no pending tasks. @@ -362,7 +371,7 @@ class TestDatasetIndexingTaskIntegration: delete_key_spy.assert_called_once() def test_validation_failure_sets_error_status_when_vector_space_at_limit( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Set error status when vector space validation fails before runner phase.""" # Arrange @@ -382,7 +391,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_error_contains(db_session_with_containers, document_ids, "over the limit") def test_runner_exception_does_not_crash_indexing_task( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Catch generic runner exceptions without crashing the task.""" # Arrange @@ -397,7 +406,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_document_paused_error_handling(self, db_session_with_containers, patched_external_dependencies): + def test_document_paused_error_handling(self, db_session_with_containers: Session, patched_external_dependencies): """Handle DocumentIsPausedError and keep persisted state consistent.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -424,7 +433,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_not_called() def test_tenant_queue_error_handling_still_processes_next_task( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Even on current task failure, enqueue the next waiting tenant task. @@ -491,7 +500,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_all_opened_sessions_closed(session_close_tracker) def test_multiple_documents_with_mixed_success_and_failure( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Process only existing documents when request includes missing ids.""" # Arrange @@ -508,7 +517,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, existing_ids) def test_tenant_queue_dispatches_up_to_concurrency_limit( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Dispatch only up to configured concurrency under queued backlog burst. @@ -543,7 +552,7 @@ class TestDatasetIndexingTaskIntegration: assert task_dispatch_spy.apply_async.call_count == concurrency_limit assert set_waiting_spy.call_count == concurrency_limit - def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies): + def test_task_queue_fifo_ordering(self, db_session_with_containers: Session, patched_external_dependencies): """Keep FIFO ordering when dispatching next queued tasks. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -576,7 +585,9 @@ class TestDatasetIndexingTaskIntegration: call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {}) assert call_kwargs.get("document_ids") == expected_task["document_ids"] - def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies): + def test_billing_disabled_skips_limit_checks( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Skip limit checks when billing feature is disabled.""" # Arrange large_document_ids = [str(uuid.uuid4()) for _ in range(100)] @@ -595,7 +606,7 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == 100 self._assert_documents_parsing(db_session_with_containers, large_document_ids) - def test_complete_workflow_normal_task(self, db_session_with_containers, patched_external_dependencies): + def test_complete_workflow_normal_task(self, db_session_with_containers: Session, patched_external_dependencies): """Run end-to-end normal queue workflow with tenant queue cleanup. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -618,7 +629,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) delete_key_spy.assert_called_once() - def test_complete_workflow_priority_task(self, db_session_with_containers, patched_external_dependencies): + def test_complete_workflow_priority_task(self, db_session_with_containers: Session, patched_external_dependencies): """Run end-to-end priority queue workflow with tenant queue cleanup. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -641,7 +652,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) delete_key_spy.assert_called_once() - def test_single_document_processing(self, db_session_with_containers, patched_external_dependencies): + def test_single_document_processing(self, db_session_with_containers: Session, patched_external_dependencies): """Process the minimum batch size (single document).""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=1) @@ -655,7 +666,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == 1 self._assert_documents_parsing(db_session_with_containers, [document_id]) - def test_document_with_special_characters_in_id(self, db_session_with_containers, patched_external_dependencies): + def test_document_with_special_characters_in_id( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Handle standard UUID ids with hyphen characters safely.""" # Arrange special_document_id = str(uuid.uuid4()) @@ -670,7 +683,9 @@ class TestDatasetIndexingTaskIntegration: # Assert self._assert_documents_parsing(db_session_with_containers, [special_document_id]) - def test_zero_vector_space_limit_allows_unlimited(self, db_session_with_containers, patched_external_dependencies): + def test_zero_vector_space_limit_allows_unlimited( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Treat vector limit 0 as unlimited and continue indexing.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -689,7 +704,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) def test_negative_vector_space_values_handled_gracefully( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Treat negative vector limits as non-blocking and continue indexing.""" # Arrange @@ -708,7 +723,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_large_document_batch_processing(self, db_session_with_containers, patched_external_dependencies): + def test_large_document_batch_processing(self, db_session_with_containers: Session, patched_external_dependencies): """Process a batch exactly at configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 48fec441c5..e4cbb9e589 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -55,7 +56,7 @@ class TestDealDatasetVectorIndexTask: yield mock_factory @pytest.fixture - def account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """Create an account with an owner tenant for testing. Returns a tuple of (account, tenant) where tenant is guaranteed to be non-None. @@ -73,7 +74,7 @@ class TestDealDatasetVectorIndexTask: return account, tenant def test_deal_dataset_vector_index_task_remove_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful removal of dataset vector index. @@ -131,7 +132,7 @@ class TestDealDatasetVectorIndexTask: assert mock_processor.clean.call_count >= 0 # For now, just check it doesn't fail def test_deal_dataset_vector_index_task_add_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful addition of dataset vector index. @@ -233,7 +234,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_update_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful update of dataset vector index. @@ -337,7 +338,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_dataset_not_found_error( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior when dataset is not found. @@ -357,7 +358,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_no_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action when no documents exist for the dataset. @@ -389,7 +390,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_no_segments( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action when documents exist but have no segments. @@ -447,7 +448,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_update_action_no_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test update action when no documents exist for the dataset. @@ -480,7 +481,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_with_exception_handling( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action with exception handling during processing. @@ -578,7 +579,7 @@ class TestDealDatasetVectorIndexTask: assert "Test exception during indexing" in updated_document.error def test_deal_dataset_vector_index_task_with_custom_index_type( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with custom index type (QA_INDEX). @@ -656,7 +657,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_default_index_type( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with default index type (PARAGRAPH_INDEX). @@ -734,7 +735,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_multiple_documents_processing( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task processing with multiple documents and segments. @@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask: assert mock_processor.load.call_count == 3 def test_deal_dataset_vector_index_task_document_status_transitions( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test document status transitions during task execution. @@ -938,7 +939,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with disabled documents. @@ -1061,7 +1062,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_archived_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with archived documents. @@ -1184,7 +1185,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_incomplete_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with documents that have incomplete indexing status. diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 8a69707b38..f4a71040c1 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -11,9 +11,19 @@ import logging from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Dataset, Document, DocumentSegment, Tenant +from models import ( + Account, + AccountStatus, + Dataset, + DatasetPermissionEnum, + Document, + DocumentSegment, + Tenant, + TenantStatus, +) from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -37,7 +47,7 @@ class TestDeleteSegmentFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_tenant(self, db_session_with_containers, fake=None): + def _create_test_tenant(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant with realistic data. @@ -49,7 +59,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status=TenantStatus.NORMAL) tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -58,7 +68,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return tenant - def _create_test_account(self, db_session_with_containers, tenant, fake=None): + def _create_test_account(self, db_session_with_containers: Session, tenant, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -75,7 +85,7 @@ class TestDeleteSegmentFromIndexTask: name=fake.name(), email=fake.email(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", ) account.id = fake.uuid4() @@ -86,7 +96,9 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return account - def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None): + def _create_test_dataset( + self, db_session_with_containers: Session, tenant: Tenant, account: Account, fake: Faker | None = None + ): """ Helper method to create a test dataset with realistic data. @@ -106,7 +118,7 @@ class TestDeleteSegmentFromIndexTask: dataset.name = f"Test Dataset {fake.word()}" dataset.description = fake.text(max_nb_chars=200) dataset.provider = "vendor" - dataset.permission = "only_me" + dataset.permission = DatasetPermissionEnum.ONLY_ME dataset.data_source_type = DataSourceType.UPLOAD_FILE dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' @@ -122,7 +134,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None, **kwargs): """ Helper method to create a test document with realistic data. @@ -172,7 +184,14 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return document - def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None): + def _create_test_document_segments( + self, + db_session_with_containers: Session, + document: Document, + account: Account, + count: int = 3, + fake: Faker | None = None, + ): """ Helper method to create test document segments with realistic data. @@ -218,7 +237,9 @@ class TestDeleteSegmentFromIndexTask: return segments @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) - def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers): + def test_delete_segment_from_index_task_success( + self, mock_index_processor_factory, db_session_with_containers: Session + ): """ Test successful segment deletion from index with comprehensive verification. @@ -267,7 +288,7 @@ class TestDeleteSegmentFromIndexTask: assert call_args[1]["with_keywords"] is True assert call_args[1]["delete_child_chunks"] is True - def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers): + def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers: Session): """ Test task behavior when dataset is not found. @@ -288,7 +309,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found - def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers: Session): """ Test task behavior when document is not found. @@ -314,7 +335,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document not found - def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers: Session): """ Test task behavior when document is disabled. @@ -342,7 +363,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled - def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers: Session): """ Test task behavior when document is archived. @@ -370,7 +391,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document is archived - def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers: Session): """ Test task behavior when document indexing is not completed. diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6e03bd9351..6bfb1e1f1e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -13,7 +13,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Dataset, DocumentSegment +from models import Account, AccountStatus, Dataset, DocumentSegment, TenantAccountRole, TenantStatus from models import Document as DatasetDocument from models.dataset import DatasetProcessRule from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus @@ -35,7 +35,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -51,24 +51,23 @@ class TestDisableSegmentsFromIndexTask: email=fake.email(), name=fake.name(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", ) - account.id = fake.uuid4() # monkey-patch attributes for test setup + account.updated_at = fake.date_time_this_year() + account.created_at = fake.date_time_this_year() + account.role = TenantAccountRole.OWNER + account.id = fake.uuid4() account.tenant_id = fake.uuid4() account.type = "normal" - account.role = "owner" - account.created_at = fake.date_time_this_year() - account.updated_at = account.created_at - # Create a tenant for the account from models.account import Tenant tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="normal", + status=TenantStatus.NORMAL, ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() @@ -83,7 +82,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None): """ Helper method to create a test dataset with realistic data. @@ -117,7 +116,9 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): + def _create_test_document( + self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None + ): """ Helper method to create a test document with realistic data. @@ -216,7 +217,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None): """ Helper method to create a dataset process rule. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index b6e7e6e5c9..77cd259833 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy import delete, func, select, update +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -162,7 +163,7 @@ class TestDocumentIndexingSyncTask: "indexing_runner": indexing_runner, } - def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None): + def _create_notion_sync_context(self, db_session_with_containers: Session, *, data_source_info: dict | None = None): account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset( db_session_with_containers, @@ -206,7 +207,7 @@ class TestDocumentIndexingSyncTask: "notion_info": notion_info, } - def test_document_not_found(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task handles missing document gracefully.""" # Arrange dataset_id = str(uuid4()) @@ -219,7 +220,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies): + def test_missing_notion_workspace_id(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when notion_workspace_id is missing.""" # Arrange context = self._create_notion_sync_context( @@ -235,7 +236,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies): + def test_missing_notion_page_id(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when notion_page_id is missing.""" # Arrange context = self._create_notion_sync_context( @@ -251,7 +252,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies): + def test_empty_data_source_info(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) @@ -264,7 +265,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies): + def test_credential_not_found(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task sets document error state when credential is missing.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -284,7 +285,7 @@ class TestDocumentIndexingSyncTask: assert updated_document.stopped_at is not None mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies): + def test_page_not_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task exits early when notion page is unchanged.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -310,7 +311,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies): + def test_successful_sync_when_page_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test full successful sync flow with SQL state updates and side effects.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -349,7 +350,7 @@ class TestDocumentIndexingSyncTask: assert len(run_documents) == 1 assert getattr(run_documents[0], "id", None) == context["document"].id - def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies): + def test_dataset_not_found_during_cleaning(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task still updates document and reindexes if dataset vanishes before clean.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -376,7 +377,9 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies): + def test_cleaning_error_continues_to_indexing( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that indexing continues when index cleanup fails.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -400,7 +403,9 @@ class TestDocumentIndexingSyncTask: assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_document_paused_error( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that DocumentIsPausedError does not flip document into error state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -418,7 +423,7 @@ class TestDocumentIndexingSyncTask: assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None - def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_general_error(self, db_session_with_containers: Session, mock_external_dependencies): """Test that indexing errors are persisted to document state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index cf1a8666f3..6c1454b6d8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.document_task import DocumentTask from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( @@ -51,7 +52,7 @@ class TestDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -71,14 +72,14 @@ class TestDocumentIndexingTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -133,7 +134,7 @@ class TestDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. @@ -153,14 +154,14 @@ class TestDocumentIndexingTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -221,7 +222,9 @@ class TestDocumentIndexingTasks: return dataset, documents - def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_document_indexing_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with multiple documents. @@ -262,7 +265,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 3 def test_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -286,7 +289,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. @@ -332,7 +335,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 2 # Only existing documents def test_document_indexing_task_indexing_runner_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of IndexingRunner exceptions. @@ -373,7 +376,7 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test processing documents with mixed initial states. @@ -456,7 +459,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 4 def test_document_indexing_task_billing_sandbox_plan_batch_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for sandbox plan batch upload limit. @@ -518,7 +521,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner"].assert_not_called() def test_document_indexing_task_billing_disabled_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful processing when billing is disabled. @@ -554,7 +557,7 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of DocumentIsPausedError from IndexingRunner. @@ -597,7 +600,9 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== - def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_old_document_indexing_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test document_indexing_task basic functionality. @@ -619,7 +624,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_normal_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_document_indexing_task basic functionality. @@ -643,7 +648,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_priority_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_document_indexing_task basic functionality. @@ -667,7 +672,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_document_indexing_with_tenant_queue_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test _document_indexing_with_tenant_queue function with no waiting tasks. @@ -717,7 +722,7 @@ class TestDocumentIndexingTasks: mock_task_func.delay.assert_not_called() def test_document_indexing_with_tenant_queue_with_waiting_tasks( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis. @@ -776,7 +781,7 @@ class TestDocumentIndexingTasks: assert len(remaining_tasks) == 1 def test_document_indexing_with_tenant_queue_error_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling in _document_indexing_with_tenant_queue using real Redis. @@ -848,7 +853,7 @@ class TestDocumentIndexingTasks: assert len(remaining_tasks) == 0 def test_document_indexing_with_tenant_queue_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation in _document_indexing_with_tenant_queue using real Redis. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index a9a8c0f30c..208fc1aa1d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,9 +3,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_update_task import document_indexing_update_task @@ -33,7 +34,7 @@ class TestDocumentIndexingUpdateTask: "runner_instance": runner_instance, } - def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2): + def _create_dataset_document_with_segments(self, db_session_with_containers: Session, *, segment_count: int = 2): fake = Faker() # Account and tenant @@ -41,12 +42,12 @@ class TestDocumentIndexingUpdateTask: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=fake.company(), status="normal") + tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -114,7 +115,7 @@ class TestDocumentIndexingUpdateTask: return dataset, document, node_ids - def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies): + def test_cleans_segments_and_reindexes(self, db_session_with_containers: Session, mock_external_dependencies): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Act @@ -153,7 +154,9 @@ class TestDocumentIndexingUpdateTask: first = run_docs[0] assert getattr(first, "id", None) == document.id - def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies): + def test_clean_error_is_logged_and_indexing_continues( + self, db_session_with_containers: Session, mock_external_dependencies + ): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Force clean to raise; task should continue to indexing @@ -173,7 +176,7 @@ class TestDocumentIndexingUpdateTask: ) assert remaining > 0 - def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found_noop(self, db_session_with_containers: Session, mock_external_dependencies): fake = Faker() # Act with non-existent document id document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 39c58987fd..12440f3e6b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -62,7 +63,7 @@ class TestDuplicateDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -145,7 +146,11 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_segments( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + document_count=3, + segments_per_doc=2, ): """ Helper method to create a test dataset with documents and segments. @@ -197,7 +202,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents, segments def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. @@ -287,7 +292,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _test_duplicate_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful duplicate document indexing with multiple documents. @@ -329,7 +334,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 3 def _test_duplicate_document_indexing_task_with_segment_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test duplicate document indexing with existing segments that need cleanup. @@ -379,7 +384,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def _test_duplicate_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -404,7 +409,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["index_processor"].clean.assert_not_called() def test_duplicate_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. @@ -450,7 +455,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 2 # Only existing documents def _test_duplicate_document_indexing_task_indexing_runner_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of IndexingRunner exceptions. @@ -491,7 +496,7 @@ class TestDuplicateDocumentIndexingTasks: assert updated_document.processing_started_at is not None def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for sandbox plan batch upload limit. @@ -554,7 +559,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for vector space limit. @@ -596,7 +601,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_duplicate_document_indexing_task_with_empty_document_list( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of empty document list. @@ -622,7 +627,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_deprecated_duplicate_document_indexing_task_delegates_to_core( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that deprecated duplicate_document_indexing_task delegates to core function. @@ -655,7 +660,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_duplicate_document_indexing_task with tenant isolation queue. @@ -698,7 +703,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_duplicate_document_indexing_task with tenant isolation queue. @@ -742,7 +747,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant queue wrapper processes next queued tasks. @@ -789,7 +794,7 @@ class TestDuplicateDocumentIndexingTasks: mock_queue.delete_task_key.assert_not_called() def test_successful_duplicate_document_indexing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test successful duplicate document indexing flow.""" self._test_duplicate_document_indexing_task_success( @@ -797,7 +802,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when dataset is not found.""" self._test_duplicate_document_indexing_task_dataset_not_found( @@ -805,7 +810,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing with billing enabled and sandbox plan.""" self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( @@ -813,7 +818,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when billing limit is exceeded.""" self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( @@ -821,7 +826,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_runner_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when IndexingRunner raises an error.""" self._test_duplicate_document_indexing_task_indexing_runner_exception( @@ -829,7 +834,7 @@ class TestDuplicateDocumentIndexingTasks: ) def _test_duplicate_document_indexing_task_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" # Arrange @@ -860,7 +865,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_duplicate_document_indexing_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" self._test_duplicate_document_indexing_task_document_is_paused( @@ -868,7 +873,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_cleans_old_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test that duplicate document indexing cleans old segments.""" self._test_duplicate_document_indexing_task_with_segment_cleanup( diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py index 177af266fb..a697878bb6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestMailChangeMailTask: "get_email_i18n_service": mock_get_email_i18n_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -72,7 +73,7 @@ class TestMailChangeMailTask: return account def test_send_change_mail_task_success_old_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for old_email phase. @@ -103,7 +104,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_success_new_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for new_email phase. @@ -134,7 +135,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when mail service is not initialized. @@ -159,7 +160,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_not_called() def test_send_change_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when email service raises an exception. @@ -191,7 +192,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email completed notification task execution. @@ -224,7 +225,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when mail service is not initialized. @@ -247,7 +248,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_email.assert_not_called() def test_send_change_mail_completed_notification_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when email service raises an exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 8343711998..8e9da6aaaa 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import delete +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -37,7 +38,7 @@ class TestSendEmailCodeLoginMailTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -71,7 +72,7 @@ class TestSendEmailCodeLoginMailTask: "email_service_instance": mock_email_service_instance, } - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account for testing. @@ -98,7 +99,7 @@ class TestSendEmailCodeLoginMailTask: return account - def _create_test_tenant_and_account(self, db_session_with_containers, fake=None): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant and account for testing. @@ -138,7 +139,7 @@ class TestSendEmailCodeLoginMailTask: return account, tenant def test_send_email_code_login_mail_task_success_english( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in English. @@ -182,7 +183,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_chinese( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in Chinese. @@ -221,7 +222,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_multiple_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending with multiple languages. @@ -261,7 +262,7 @@ class TestSendEmailCodeLoginMailTask: assert call_args[1]["template_context"]["code"] == test_codes[i] def test_send_email_code_login_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when mail service is not initialized. @@ -299,7 +300,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_not_called() def test_send_email_code_login_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when email service raises an exception. @@ -346,7 +347,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_invalid_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with invalid parameters. @@ -388,7 +389,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_called_once() def test_send_email_code_login_mail_task_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with edge cases and boundary conditions. @@ -451,7 +452,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with database integration. @@ -497,7 +498,7 @@ class TestSendEmailCodeLoginMailTask: assert account.status == "active" def test_send_email_code_login_mail_task_redis_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with Redis integration. @@ -541,7 +542,7 @@ class TestSendEmailCodeLoginMailTask: redis_client.delete(cache_key) def test_send_email_code_login_mail_task_error_handling_comprehensive( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error handling for email code login mail task. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 95a867dbb5..f505361727 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from sqlalchemy import delete +from sqlalchemy.orm import Session from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -172,7 +173,9 @@ def _create_workflow_pause_state( db_session_with_containers.commit() -def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): +def test_dispatch_human_input_email_task_integration( + monkeypatch: pytest.MonkeyPatch, db_session_with_containers: Session +): tenant, account = _create_workspace_member(db_session_with_containers) workflow_run_id = str(uuid.uuid4()) workflow_id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py index 1a20b6deec..f8e54ea9e6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from tasks.mail_inner_task import send_inner_email_task @@ -51,7 +52,7 @@ class TestMailInnerTask: }, } - def test_send_inner_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful email sending with valid data. @@ -90,7 +91,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_single_recipient(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_single_recipient( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with single recipient. @@ -126,7 +129,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_empty_substitutions(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_empty_substitutions( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with empty substitutions. @@ -163,7 +168,7 @@ class TestMailInnerTask: ) def test_send_inner_email_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when mail service is not initialized. @@ -193,7 +198,7 @@ class TestMailInnerTask: mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() def test_send_inner_email_template_rendering_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when template rendering fails. @@ -222,7 +227,9 @@ class TestMailInnerTask: # Verify no email service calls due to exception mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() - def test_send_inner_email_service_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_service_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending when email service fails. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index d34828c4b1..c8c7a4d961 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import delete, select +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from libs.email_i18n import EmailType @@ -42,7 +43,7 @@ class TestMailInviteMemberTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" # Clear all test data db_session_with_containers.execute(delete(TenantAccountJoin)) @@ -78,7 +79,7 @@ class TestMailInviteMemberTask: "config": mock_config, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -147,7 +148,7 @@ class TestMailInviteMemberTask: redis_client.setex(cache_key, 24 * 60 * 60, json.dumps(invitation_data)) # 24 hours return token - def _create_pending_account_for_invitation(self, db_session_with_containers, email, tenant): + def _create_pending_account_for_invitation(self, db_session_with_containers: Session, email, tenant): """ Helper method to create a pending account for invitation testing. @@ -185,7 +186,9 @@ class TestMailInviteMemberTask: return account - def test_send_invite_member_mail_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_invite_member_mail_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful invitation email sending with all parameters. @@ -231,7 +234,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_different_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test invitation email sending with different language codes. @@ -263,7 +266,7 @@ class TestMailInviteMemberTask: assert call_args[1]["language_code"] == language def test_send_invite_member_mail_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test behavior when mail service is not initialized. @@ -292,7 +295,7 @@ class TestMailInviteMemberTask: mock_email_service.send_email.assert_not_called() def test_send_invite_member_mail_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when email service raises an exception. @@ -322,7 +325,7 @@ class TestMailInviteMemberTask: assert "Send invite member mail to %s failed" in error_call def test_send_invite_member_mail_template_context_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test template context contains all required fields for email rendering. @@ -368,7 +371,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_integration_with_redis_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test integration with Redis token validation. @@ -407,7 +410,7 @@ class TestMailInviteMemberTask: assert invitation_data["workspace_id"] == tenant.id def test_send_invite_member_mail_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending with special characters in names and workspace names. @@ -449,7 +452,7 @@ class TestMailInviteMemberTask: assert template_context["workspace_name"] == workspace_name def test_send_invite_member_mail_real_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test real database integration with actual invitation flow. @@ -501,7 +504,7 @@ class TestMailInviteMemberTask: assert tenant_join.role == TenantAccountRole.NORMAL def test_send_invite_member_mail_token_lifecycle_management( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test token lifecycle management and validation. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py index e08b099480..176645a4ab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py @@ -11,6 +11,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -44,7 +45,7 @@ class TestMailOwnerTransferTask: "get_email_service": mock_get_email_service, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create test account and tenant for testing. @@ -86,7 +87,9 @@ class TestMailOwnerTransferTask: return account, tenant - def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_owner_transfer_confirm_task_success( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """ Test successful owner transfer confirmation email sending. @@ -127,7 +130,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_owner_transfer_confirm_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test owner transfer confirmation email when mail service is not initialized. @@ -158,7 +161,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_owner_transfer_confirm_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in owner transfer confirmation email. @@ -192,7 +195,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_old_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful old owner transfer notification email sending. @@ -234,7 +237,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email def test_send_old_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test old owner transfer notification email when mail service is not initialized. @@ -265,7 +268,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_old_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in old owner transfer notification email. @@ -299,7 +302,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_new_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful new owner transfer notification email sending. @@ -338,7 +341,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_new_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test new owner transfer notification email when mail service is not initialized. @@ -367,7 +370,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_new_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in new owner transfer notification email. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py index cced6f7780..071971f324 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist @@ -35,7 +36,7 @@ class TestMailRegisterTask: "get_email_service": mock_get_email_service, } - def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_success(self, db_session_with_containers: Session, mock_mail_dependencies): """Test successful email registration mail sending.""" fake = Faker() language = "en-US" @@ -56,7 +57,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test email registration task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -66,7 +67,9 @@ class TestMailRegisterTask: mock_mail_dependencies["get_email_service"].assert_not_called() mock_mail_dependencies["email_service"].send_email.assert_not_called() - def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_exception_handling( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """Test email registration task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") @@ -79,7 +82,7 @@ class TestMailRegisterTask: mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) def test_send_email_register_mail_task_when_account_exist_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test successful email registration mail sending when account exists.""" fake = Faker() @@ -105,7 +108,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_when_account_exist_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -118,7 +121,7 @@ class TestMailRegisterTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_email_register_mail_task_when_account_exist_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index f01fcc1742..5eea985fdc 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,12 +4,13 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from flask import Flask from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Pipeline from models.workflow import Workflow from tasks.rag_pipeline.priority_rag_pipeline_run_task import ( @@ -69,14 +70,14 @@ class TestRagPipelineRunTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -725,7 +726,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test successful run_single_rag_pipeline_task execution. @@ -760,7 +761,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -805,7 +806,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with non-existent database entities. diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index b43b622870..03c02ea341 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -3,6 +3,7 @@ from unittest.mock import ANY, call, patch import pytest from sqlalchemy import delete, func, select +from sqlalchemy.orm import Session from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType @@ -117,7 +118,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: class TestDeleteDraftVariablesBatch: - def test_delete_draft_variables_batch_success(self, db_session_with_containers): + def test_delete_draft_variables_batch_success(self, db_session_with_containers: Session): """Test successful deletion of draft variables in batches.""" _, app1 = _create_tenant_and_app(db_session_with_containers) _, app2 = _create_tenant_and_app(db_session_with_containers) @@ -137,7 +138,7 @@ class TestDeleteDraftVariablesBatch: assert app1_remaining_count == 0 assert app2_remaining_count == 100 - def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers): + def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers: Session): """Test deletion when no draft variables exist for the app.""" result = delete_draft_variables_batch(str(uuid.uuid4()), 1000) @@ -176,7 +177,7 @@ class TestDeleteDraftVariableOffloadData: """Test the Offload data cleanup functionality.""" @patch("extensions.ext_storage.storage") - def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers): + def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers: Session): """Test successful deletion of offload data.""" tenant, app = _create_tenant_and_app(db_session_with_containers) offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3) diff --git a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py index 34a1941c39..6365207661 100644 --- a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py +++ b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py @@ -1,12 +1,14 @@ from pathlib import Path +import pytest + from extensions.storage.opendal_storage import OpenDALStorage class TestOpenDALFsDefaultRoot: """Test that OpenDALStorage with scheme='fs' works correctly when no root is provided.""" - def test_fs_without_root_uses_default(self, tmp_path, monkeypatch): + def test_fs_without_root_uses_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When no root is specified, the default 'storage' should be used and passed to the Operator.""" # Change to tmp_path so the default "storage" dir is created there monkeypatch.chdir(tmp_path) @@ -25,7 +27,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_default_root.txt") - def test_fs_with_explicit_root(self, tmp_path): + def test_fs_with_explicit_root(self, tmp_path: Path): """When root is explicitly provided, it should be used.""" custom_root = str(tmp_path / "custom_storage") storage = OpenDALStorage(scheme="fs", root=custom_root) @@ -38,7 +40,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_explicit_root.txt") - def test_fs_with_env_var_root(self, tmp_path, monkeypatch): + def test_fs_with_env_var_root(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs.""" env_root = str(tmp_path / "env_storage") monkeypatch.setenv("OPENDAL_FS_ROOT", env_root) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index b00d827e37..6402e7da2b 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -175,7 +175,7 @@ class TestWorkflowPauseIntegration: """Comprehensive integration tests for workflow pause functionality.""" @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers): + def setup_test_data(self, db_session_with_containers: Session): """Set up test data for each test method using TestContainers.""" # Create test tenant and account diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 19a41b6186..a5086b4c5d 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -1,12 +1,14 @@ from textwrap import dedent +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): """Test class for JavaScript code executor functionality.""" - def test_javascript_plain(self, flask_app_with_containers): + def test_javascript_plain(self, flask_app_with_containers: Flask): """Test basic JavaScript code execution with console.log output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -14,7 +16,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) assert result_message == "Hello World\n" - def test_javascript_json(self, flask_app_with_containers): + def test_javascript_json(self, flask_app_with_containers: Flask): """Test JavaScript code execution with JSON output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -25,7 +27,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) assert result == '{"Hello":"World"}\n' - def test_javascript_with_code_template(self, flask_app_with_containers): + def test_javascript_with_code_template(self, flask_app_with_containers: Flask): """Test JavaScript workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports JavascriptCodeProvider, _ = self.javascript_imports @@ -37,7 +39,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "HelloWorld"} - def test_javascript_get_runner_script(self, flask_app_with_containers): + def test_javascript_get_runner_script(self, flask_app_with_containers: Flask): """Test JavaScript template transformer runner script generation""" _, NodeJsTemplateTransformer = self.javascript_imports diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index ddb079f00c..8b4c3c3d4a 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -1,12 +1,14 @@ import base64 +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestJinja2CodeExecutor(CodeExecutorTestMixin): """Test class for Jinja2 code executor functionality.""" - def test_jinja2(self, flask_app_with_containers): + def test_jinja2(self, flask_app_with_containers: Flask): """Test basic Jinja2 template execution with variable substitution""" CodeExecutor, CodeLanguage = self.code_executor_imports _, Jinja2TemplateTransformer = self.jinja2_imports @@ -25,7 +27,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): ) assert result == "<>Hello World<>\n" - def test_jinja2_with_code_template(self, flask_app_with_containers): + def test_jinja2_with_code_template(self, flask_app_with_containers: Flask): """Test Jinja2 workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -34,7 +36,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "Hello World"} - def test_jinja2_get_runner_script(self, flask_app_with_containers): + def test_jinja2_get_runner_script(self, flask_app_with_containers: Flask): """Test Jinja2 template transformer runner script generation""" _, Jinja2TemplateTransformer = self.jinja2_imports @@ -43,7 +45,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 - def test_jinja2_template_with_special_characters(self, flask_app_with_containers): + def test_jinja2_template_with_special_characters(self, flask_app_with_containers: Flask): """ Test that templates with special characters (quotes, newlines) render correctly. This is a regression test for issue #26818 where textarea pre-fill values diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py index 6d93df2472..0de41e1312 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,12 +1,14 @@ from textwrap import dedent +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestPython3CodeExecutor(CodeExecutorTestMixin): """Test class for Python3 code executor functionality.""" - def test_python3_plain(self, flask_app_with_containers): + def test_python3_plain(self, flask_app_with_containers: Flask): """Test basic Python3 code execution with print output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -14,7 +16,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) assert result == "Hello World\n" - def test_python3_json(self, flask_app_with_containers): + def test_python3_json(self, flask_app_with_containers: Flask): """Test Python3 code execution with JSON output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -25,7 +27,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) assert result == '{"Hello": "World"}\n' - def test_python3_with_code_template(self, flask_app_with_containers): + def test_python3_with_code_template(self, flask_app_with_containers: Flask): """Test Python3 workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports Python3CodeProvider, _ = self.python3_imports @@ -37,7 +39,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "HelloWorld"} - def test_python3_get_runner_script(self, flask_app_with_containers): + def test_python3_get_runner_script(self, flask_app_with_containers: Flask): """Test Python3 template transformer runner script generation""" _, Python3TemplateTransformer = self.python3_imports diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index d3e864a75a..78413a0798 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -67,7 +67,7 @@ class TestActivateCheckApi: assert response["data"]["email"] == "invitee@example.com" @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_invalid_invitation_token(self, mock_get_invitation, app): + def test_check_invalid_invitation_token(self, mock_get_invitation, app: Flask): """ Test checking invalid invitation token. @@ -227,7 +227,7 @@ class TestActivateApi: mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_activation_with_invalid_token(self, mock_get_invitation, app): + def test_activation_with_invalid_token(self, mock_get_invitation, app: Flask): """ Test account activation with invalid token. diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index b7bc73da5f..7b2c7569fe 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -140,7 +140,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") - def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending blocked by IP rate limit. @@ -160,7 +160,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.login.AccountService.get_user_through_email") - def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app): + def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending to frozen account. @@ -353,7 +353,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app): + def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app: Flask): """ Test email code login with invalid token. @@ -375,7 +375,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app): + def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app: Flask): """ Test email code login with mismatched email. @@ -397,7 +397,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app): + def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app: Flask): """ Test email code login with incorrect code. diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index d089be8905..5284f29eed 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -9,7 +9,7 @@ This module tests the core authentication endpoints including: """ import base64 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from flask import Flask @@ -52,12 +52,12 @@ class TestLoginApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(LoginApi, "/login") return app.test_client() @@ -97,7 +97,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -141,14 +141,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") def test_successful_login_with_valid_invitation( self, - mock_reset_rate_limit, + mock_reset_rate_limit: Mock, mock_login, mock_get_tenants, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -188,7 +188,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login rejection when rate limit is exceeded. @@ -216,7 +216,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @patch("controllers.console.auth.login.BillingService.is_email_in_freeze") - def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app): + def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask): """ Test login rejection for frozen accounts. @@ -253,7 +253,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, ): """ Test login failure with invalid credentials. @@ -290,7 +290,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( - self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app + self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask ): """ Test login rejection for banned accounts. @@ -328,14 +328,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.FeatureService.get_system_features") def test_login_fails_when_no_workspace_and_limit_exceeded( self, - mock_get_features, - mock_get_tenants, - mock_authenticate, - mock_get_invitation, - mock_is_rate_limit, - mock_db, - app, - mock_account, + mock_get_features: MagicMock, + mock_get_tenants: MagicMock, + mock_authenticate: MagicMock, + mock_get_invitation: MagicMock, + mock_is_rate_limit: MagicMock, + mock_db: MagicMock, + app: Flask, + mock_account: MagicMock, ): """ Test login failure when user has no workspace and workspace limit exceeded. @@ -367,7 +367,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login failure when invitation email doesn't match login email. @@ -491,7 +491,7 @@ class TestLogoutApi: @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") def test_successful_logout( - self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account + self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account ): """ Test successful logout flow. @@ -518,7 +518,7 @@ class TestLogoutApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.current_account_with_tenant") @patch("controllers.console.auth.login.flask_login") - def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app): + def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask): """ Test logout for anonymous (not logged in) user. diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index d010f60866..15c95f6b94 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -28,12 +28,12 @@ class TestRefreshTokenApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(RefreshTokenApi, "/refresh-token") return app.test_client() diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index 810f1b94fc..defa9064fd 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -49,7 +49,7 @@ class TestPartnerTenants: mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} - def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_success(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test successful partner tenants bindings sync.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -79,7 +79,7 @@ class TestPartnerTenants: mock_account.id, "partner-key-123", click_id ) - def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_invalid_partner_key_base64(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that invalid base64 partner_key raises BadRequest.""" # Arrange invalid_partner_key = "invalid-base64-!@#$" @@ -104,7 +104,7 @@ class TestPartnerTenants: resource.put(invalid_partner_key) assert "Invalid partner_key" in str(exc_info.value) - def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_missing_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that missing click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -128,7 +128,9 @@ class TestPartnerTenants: with pytest.raises(BadRequest): resource.put(partner_key_encoded) - def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_billing_service_json_decode_error( + self, app: Flask, mock_account, mock_billing_service, mock_decorators + ): """Test handling of billing service JSON decode error. When billing service returns non-200 status code with invalid JSON response, @@ -174,7 +176,7 @@ class TestPartnerTenants: assert isinstance(exc_info.value, json.JSONDecodeError) assert "Expecting value" in str(exc_info.value) - def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -199,7 +201,7 @@ class TestPartnerTenants: resource.put(partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_partner_key_after_decode(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty partner_key after decode raises BadRequest.""" # Arrange # Base64 encode an empty string @@ -225,7 +227,7 @@ class TestPartnerTenants: resource.put(empty_partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_user_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty user id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 032b1377a4..99a90f3b67 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -17,7 +17,7 @@ def app(): return app -def test_parse_openapi_to_tool_bundle_operation_id(app): +def test_parse_openapi_to_tool_bundle_operation_id(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -63,7 +63,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): assert tool_bundles[2].operation_id == "createResource" -def test_parse_openapi_to_tool_bundle_properties_all_of(app): +def test_parse_openapi_to_tool_bundle_properties_all_of(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -118,7 +118,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} -def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): +def test_parse_openapi_to_tool_bundle_default_value_type_casting(app: Flask): """ Test that default values are properly cast to match parameter types. This addresses the issue where array default values like [] cause validation errors diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py index 762d7b9090..e7f7cabecd 100644 --- a/api/tests/unit_tests/services/controller_api.py +++ b/api/tests/unit_tests/services/controller_api.py @@ -146,7 +146,7 @@ class ControllerApiTestDataFactory: return app @staticmethod - def create_api_instance(app): + def create_api_instance(app: Flask): """ Create a Flask-RESTX API instance. @@ -160,7 +160,12 @@ class ControllerApiTestDataFactory: return api @staticmethod - def create_test_client(app, api, resource_class, route): + def create_test_client( + app: Flask, + api: Api, + resource_class: type, + route: str, + ): """ Create a Flask test client with a resource registered. @@ -302,7 +307,7 @@ class TestDatasetListApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """ Create Flask-RESTX API instance. @@ -311,7 +316,7 @@ class TestDatasetListApi: return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """ Create test client with DatasetListApi registered. @@ -472,12 +477,12 @@ class TestDatasetApiGet: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with DatasetApi registered.""" return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets/") @@ -588,12 +593,12 @@ class TestDatasetApiCreate: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with DatasetApi registered.""" return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets") @@ -681,12 +686,12 @@ class TestHitTestingApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with HitTestingApi registered.""" return ControllerApiTestDataFactory.create_test_client( app, api, HitTestingApi, "/datasets//hit-testing" @@ -799,12 +804,12 @@ class TestExternalDatasetApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client_list(self, app, api): + def client_list(self, app: Flask, api: Api): """Create test client for external knowledge API list endpoint.""" return ControllerApiTestDataFactory.create_test_client( app, api, ExternalApiTemplateListApi, "/datasets/external-knowledge-api" diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index fe4c10329e..4a4742adcf 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -4156,8 +4156,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesResponse export type DeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdData = { body?: never path: { - app_id: string variable_id: string + app_id: string } query?: never url: '/apps/{app_id}/workflows/draft/variables/{variable_id}' @@ -4210,8 +4210,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesByVariableIdResponse export type PatchAppsByAppIdWorkflowsDraftVariablesByVariableIdData = { body: WorkflowDraftVariableUpdatePayload path: { - app_id: string variable_id: string + app_id: string } query?: never url: '/apps/{app_id}/workflows/draft/variables/{variable_id}' diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index dcaeaed246..9798d22cc0 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -2980,8 +2980,8 @@ export const zGetAppsByAppIdWorkflowsDraftVariablesQuery = z.object({ export const zGetAppsByAppIdWorkflowsDraftVariablesResponse = zWorkflowDraftVariableListWithoutValue export const zDeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({ - app_id: z.string(), variable_id: z.string(), + app_id: z.string(), }) /** @@ -3006,8 +3006,8 @@ export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdBody = zWorkflowDraftVariableUpdatePayload export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({ - app_id: z.string(), variable_id: z.string(), + app_id: z.string(), }) /** diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 89a68593b7..61d380d686 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -255,6 +255,7 @@ export type ProcessRule = { } export type RetrievalModel = { + metadata_filtering_conditions?: MetadataFilteringCondition reranking_enable: boolean reranking_mode?: string | null reranking_model?: RerankingModel @@ -312,6 +313,11 @@ export type Rule = { subchunk_segmentation?: Segmentation } +export type MetadataFilteringCondition = { + conditions?: Array | null + logical_operator?: 'and' | 'or' | null +} + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null @@ -405,6 +411,30 @@ export type Segmentation = { separator?: string } +export type Condition = { + comparison_operator: + | 'contains' + | 'not contains' + | 'start with' + | 'end with' + | 'is' + | 'is not' + | 'empty' + | 'not empty' + | 'in' + | 'not in' + | '=' + | '≠' + | '>' + | '<' + | '≥' + | '≤' + | 'before' + | 'after' + name: string + value?: unknown +} + export type WeightKeywordSetting = { keyword_weight: number } @@ -1174,8 +1204,8 @@ export type PatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse export type DeleteDatasetsByDatasetIdDocumentsByDocumentIdData = { body?: never path: { - dataset_id: string document_id: string + dataset_id: string } query?: never url: '/datasets/{dataset_id}/documents/{document_id}' diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index 2ac2cbfd1f..76491c52a0 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -392,6 +392,46 @@ export const zProcessRule = z.object({ rules: zRule.optional(), }) +/** + * Condition + * + * Condition detail + */ +export const zCondition = z.object({ + comparison_operator: z.enum([ + 'contains', + 'not contains', + 'start with', + 'end with', + 'is', + 'is not', + 'empty', + 'not empty', + 'in', + 'not in', + '=', + '≠', + '>', + '<', + '≥', + '≤', + 'before', + 'after', + ]), + name: z.string(), + value: z.unknown().optional(), +}) + +/** + * MetadataFilteringCondition + * + * Metadata Filtering Condition. + */ +export const zMetadataFilteringCondition = z.object({ + conditions: z.array(zCondition).nullish(), + logical_operator: z.enum(['and', 'or']).nullish().default('and'), +}) + /** * WeightKeywordSetting */ @@ -421,6 +461,7 @@ export const zWeightModel = z.object({ * RetrievalModel */ export const zRetrievalModel = z.object({ + metadata_filtering_conditions: zMetadataFilteringCondition.optional(), reranking_enable: z.boolean(), reranking_mode: z.string().nullish(), reranking_model: zRerankingModel.optional(), @@ -925,8 +966,8 @@ export const zPatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse = z.r ) export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdPath = z.object({ - dataset_id: z.string(), document_id: z.string(), + dataset_id: z.string(), }) /** diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index f491c1e3f9..e3791e295c 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -325,8 +325,37 @@ export type WorkflowRunResponse = { workflow_id: string } +export type Condition = { + comparison_operator: + | 'contains' + | 'not contains' + | 'start with' + | 'end with' + | 'is' + | 'is not' + | 'empty' + | 'not empty' + | 'in' + | 'not in' + | '=' + | '≠' + | '>' + | '<' + | '≥' + | '≤' + | 'before' + | 'after' + name: string + value?: unknown +} + export type DatasetPermissionEnum = 'only_me' | 'all_team_members' | 'partial_members' +export type MetadataFilteringCondition = { + conditions?: Array | null + logical_operator?: 'and' | 'or' | null +} + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null @@ -339,6 +368,7 @@ export type RetrievalMethod | 'keyword_search' export type RetrievalModel = { + metadata_filtering_conditions?: MetadataFilteringCondition reranking_enable: boolean reranking_mode?: string | null reranking_model?: RerankingModel @@ -1833,8 +1863,8 @@ export type GetDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdData = body?: never path: { segment_id: string - dataset_id: string document_id: string + dataset_id: string } query?: never url: '/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 2c2400c0cb..6feacbdead 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -326,11 +326,51 @@ export const zWorkflowRunResponse = z.object({ workflow_id: z.string(), }) +/** + * Condition + * + * Condition detail + */ +export const zCondition = z.object({ + comparison_operator: z.enum([ + 'contains', + 'not contains', + 'start with', + 'end with', + 'is', + 'is not', + 'empty', + 'not empty', + 'in', + 'not in', + '=', + '≠', + '>', + '<', + '≥', + '≤', + 'before', + 'after', + ]), + name: z.string(), + value: z.unknown().optional(), +}) + /** * DatasetPermissionEnum */ export const zDatasetPermissionEnum = z.enum(['only_me', 'all_team_members', 'partial_members']) +/** + * MetadataFilteringCondition + * + * Metadata Filtering Condition. + */ +export const zMetadataFilteringCondition = z.object({ + conditions: z.array(zCondition).nullish(), + logical_operator: z.enum(['and', 'or']).nullish().default('and'), +}) + /** * RerankingModel */ @@ -378,6 +418,7 @@ export const zWeightModel = z.object({ * RetrievalModel */ export const zRetrievalModel = z.object({ + metadata_filtering_conditions: zMetadataFilteringCondition.optional(), reranking_enable: z.boolean(), reranking_mode: z.string().nullish(), reranking_model: zRerankingModel.optional(), @@ -1082,8 +1123,8 @@ export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdR export const zGetDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdPath = z.object({ segment_id: z.string(), - dataset_id: z.string(), document_id: z.string(), + dataset_id: z.string(), }) /**