diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index f73e2da54e..b9e876c906 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -32,12 +32,7 @@ class TagBindingPayload(BaseModel): class TagBindingRemovePayload(BaseModel): - tag_id: str = Field(description="Tag ID to remove") - target_id: str = Field(description="Target ID to unbind tag from") - type: TagType = Field(description="Tag type") - - -class TagBindingItemDeletePayload(BaseModel): + tag_ids: list[str] = Field(description="Tag IDs to remove", min_length=1) target_id: str = Field(description="Target ID to unbind tag from") type: TagType = Field(description="Tag type") @@ -75,7 +70,6 @@ register_schema_models( TagBasePayload, TagBindingPayload, TagBindingRemovePayload, - TagBindingItemDeletePayload, TagListQueryParam, TagResponse, ) @@ -184,13 +178,13 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]: return {"result": "success"}, 200 -def _remove_tag_binding() -> tuple[dict[str, str], int]: +def _remove_tag_bindings() -> tuple[dict[str, str], int]: _require_tag_binding_edit_permission() payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) TagService.delete_tag_binding( TagBindingDeletePayload( - tag_id=payload.tag_id, + tag_ids=payload.tag_ids, target_id=payload.target_id, type=payload.type, ) @@ -211,54 +205,15 @@ class TagBindingCollectionApi(Resource): return _create_tag_bindings() -@console_ns.route("/tag-bindings/") -class TagBindingItemApi(Resource): - """Canonical item resource for tag binding deletion.""" - - @console_ns.doc("delete_tag_binding") - @console_ns.doc(params={"id": "Tag ID"}) - @console_ns.expect(console_ns.models[TagBindingItemDeletePayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def delete(self, id): - _require_tag_binding_edit_permission() - payload = TagBindingItemDeletePayload.model_validate(console_ns.payload or {}) - TagService.delete_tag_binding( - TagBindingDeletePayload( - tag_id=str(id), - target_id=payload.target_id, - type=payload.type, - ) - ) - return {"result": "success"}, 200 - - -@console_ns.route("/tag-bindings/create") -class DeprecatedTagBindingCreateApi(Resource): - """Deprecated verb-based alias for tag binding creation.""" - - @console_ns.doc("create_tag_binding_deprecated") - @console_ns.doc(deprecated=True) - @console_ns.doc(description="Deprecated legacy alias. Use POST /tag-bindings instead.") - @console_ns.expect(console_ns.models[TagBindingPayload.__name__]) - @setup_required - @login_required - @account_initialization_required - def post(self): - return _create_tag_bindings() - - @console_ns.route("/tag-bindings/remove") -class DeprecatedTagBindingRemoveApi(Resource): - """Deprecated verb-based alias for tag binding deletion.""" +class TagBindingRemoveApi(Resource): + """Batch resource for tag binding deletion.""" - @console_ns.doc("delete_tag_binding_deprecated") - @console_ns.doc(deprecated=True) - @console_ns.doc(description="Deprecated legacy alias. Use DELETE /tag-bindings/{id} instead.") + @console_ns.doc("remove_tag_bindings") + @console_ns.doc(description="Remove one or more tag bindings from a target.") @console_ns.expect(console_ns.models[TagBindingRemovePayload.__name__]) @setup_required @login_required @account_initialization_required def post(self): - return _remove_tag_binding() + return _remove_tag_bindings() diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 76519cad0a..3eb773fa7c 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,7 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import marshal -from pydantic import BaseModel, Field, TypeAdapter, field_validator +from pydantic import BaseModel, Field, TypeAdapter, field_validator, model_validator from werkzeug.exceptions import Forbidden, NotFound import services @@ -100,9 +100,27 @@ class TagBindingPayload(BaseModel): class TagUnbindingPayload(BaseModel): - tag_id: str + """Accept the legacy single-tag Service API payload while exposing a normalized tag_ids list internally.""" + + tag_ids: list[str] = Field(default_factory=list) + tag_id: str | None = None target_id: str + @model_validator(mode="before") + @classmethod + def normalize_legacy_tag_id(cls, data: object) -> object: + if not isinstance(data, dict): + return data + if not data.get("tag_ids") and data.get("tag_id"): + return {**data, "tag_ids": [data["tag_id"]]} + return data + + @model_validator(mode="after") + def validate_tag_ids(self) -> "TagUnbindingPayload": + if not self.tag_ids: + raise ValueError("Tag IDs is required.") + return self + class DatasetListQuery(BaseModel): page: int = Field(default=1, description="Page number") @@ -601,11 +619,11 @@ class DatasetTagBindingApi(DatasetApiResource): @service_api_ns.route("/datasets/tags/unbinding") class DatasetTagUnbindingApi(DatasetApiResource): @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__]) - @service_api_ns.doc("unbind_dataset_tag") - @service_api_ns.doc(description="Unbind a tag from a dataset") + @service_api_ns.doc("unbind_dataset_tags") + @service_api_ns.doc(description="Unbind tags from a dataset") @service_api_ns.doc( responses={ - 204: "Tag unbound successfully", + 204: "Tags unbound successfully", 401: "Unauthorized - invalid API token", 403: "Forbidden - insufficient permissions", } @@ -618,7 +636,7 @@ class DatasetTagUnbindingApi(DatasetApiResource): payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {}) TagService.delete_tag_binding( - TagBindingDeletePayload(tag_id=payload.tag_id, target_id=payload.target_id, type=TagType.KNOWLEDGE) + TagBindingDeletePayload(tag_ids=payload.tag_ids, target_id=payload.target_id, type=TagType.KNOWLEDGE) ) return "", 204 diff --git a/api/extensions/ext_session_factory.py b/api/extensions/ext_session_factory.py index 0eb43d66f4..e19ccd11e5 100644 --- a/api/extensions/ext_session_factory.py +++ b/api/extensions/ext_session_factory.py @@ -1,7 +1,9 @@ +from flask import Flask + from core.db.session_factory import configure_session_factory from extensions.ext_database import db -def init_app(app): +def init_app(app: Flask): with app.app_context(): configure_session_factory(db.engine) diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py index ece061db67..6283dbb986 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/src/dify_vdb_tidb_on_qdrant/tidb_service.py @@ -246,8 +246,18 @@ class TidbService: userPrefix = item["userPrefix"] if state == "ACTIVE" and len(userPrefix) > 0: cluster_info = tidb_serverless_list_map[item["clusterId"]] - cluster_info.status = TidbAuthBindingStatus.ACTIVE cluster_info.account = f"{userPrefix}.root" + if not cluster_info.qdrant_endpoint: + cluster_info.qdrant_endpoint = TidbService.extract_qdrant_endpoint( + item + ) or TidbService.fetch_qdrant_endpoint(api_url, public_key, private_key, item["clusterId"]) + if cluster_info.qdrant_endpoint: + cluster_info.status = TidbAuthBindingStatus.ACTIVE + else: + logger.warning( + "Cluster %s is ACTIVE but qdrant endpoint is not ready; will retry later", + item["clusterId"], + ) db.session.add(cluster_info) db.session.commit() else: diff --git a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py index c1ffbacbbc..20a42f6cc3 100644 --- a/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py +++ b/api/providers/vdb/vdb-tidb-on-qdrant/tests/unit_tests/test_tidb_service.py @@ -1,8 +1,11 @@ +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from dify_vdb_tidb_on_qdrant.tidb_service import TidbService +from models.enums import TidbAuthBindingStatus + class TestExtractQdrantEndpoint: """Unit tests for TidbService.extract_qdrant_endpoint.""" @@ -216,3 +219,86 @@ class TestBatchCreateEdgeCases: private_key="priv", region="us-east-1", ) + + +class TestBatchUpdateTidbServerlessClusterStatus: + """Verify that status updates only expose clusters after qdrant endpoint is ready.""" + + @patch("dify_vdb_tidb_on_qdrant.tidb_service.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + def test_sets_active_when_batch_response_contains_endpoint(self, mock_http, mock_db): + binding = SimpleNamespace( + cluster_id="c-1", + status=TidbAuthBindingStatus.CREATING, + account="root", + qdrant_endpoint=None, + ) + mock_http.get.return_value = MagicMock( + status_code=200, + json=lambda: { + "clusters": [ + { + "clusterId": "c-1", + "state": "ACTIVE", + "userPrefix": "pfx", + "endpoints": {"public": {"host": "gw.tidbcloud.com"}}, + } + ] + }, + ) + + TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv") + + assert binding.account == "pfx.root" + assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com" + assert binding.status == TidbAuthBindingStatus.ACTIVE + mock_db.session.add.assert_called_once_with(binding) + mock_db.session.commit.assert_called_once() + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value="https://qdrant-gw.tidbcloud.com") + @patch("dify_vdb_tidb_on_qdrant.tidb_service.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + def test_fetches_endpoint_when_batch_response_omits_it(self, mock_http, mock_db, mock_fetch_endpoint): + binding = SimpleNamespace( + cluster_id="c-1", + status=TidbAuthBindingStatus.CREATING, + account="root", + qdrant_endpoint=None, + ) + mock_http.get.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]}, + ) + + TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv") + + assert binding.account == "pfx.root" + assert binding.qdrant_endpoint == "https://qdrant-gw.tidbcloud.com" + assert binding.status == TidbAuthBindingStatus.ACTIVE + mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1") + mock_db.session.add.assert_called_once_with(binding) + mock_db.session.commit.assert_called_once() + + @patch.object(TidbService, "fetch_qdrant_endpoint", return_value=None) + @patch("dify_vdb_tidb_on_qdrant.tidb_service.db") + @patch("dify_vdb_tidb_on_qdrant.tidb_service._tidb_http_client") + def test_keeps_creating_when_endpoint_is_not_ready(self, mock_http, mock_db, mock_fetch_endpoint): + binding = SimpleNamespace( + cluster_id="c-1", + status=TidbAuthBindingStatus.CREATING, + account="root", + qdrant_endpoint=None, + ) + mock_http.get.return_value = MagicMock( + status_code=200, + json=lambda: {"clusters": [{"clusterId": "c-1", "state": "ACTIVE", "userPrefix": "pfx", "endpoints": {}}]}, + ) + + TidbService.batch_update_tidb_serverless_cluster_status([binding], "proj", "url", "iam", "pub", "priv") + + assert binding.account == "pfx.root" + assert binding.qdrant_endpoint is None + assert binding.status == TidbAuthBindingStatus.CREATING + mock_fetch_endpoint.assert_called_once_with("url", "pub", "priv", "c-1") + mock_db.session.add.assert_called_once_with(binding) + mock_db.session.commit.assert_called_once() diff --git a/api/pyproject.toml b/api/pyproject.toml index bcbde0842b..69add5c68d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -174,7 +174,7 @@ dev = [ # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.62.0", + "pyrefly>=0.64.0", "xinference-client>=2.7.0", ] diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 1882c855ea..8043a99be1 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -1,9 +1,11 @@ import uuid +from typing import cast import sqlalchemy as sa from flask_login import current_user from pydantic import BaseModel, Field -from sqlalchemy import func, select +from sqlalchemy import delete, func, select +from sqlalchemy.engine import CursorResult from werkzeug.exceptions import NotFound from extensions.ext_database import db @@ -29,7 +31,7 @@ class TagBindingCreatePayload(BaseModel): class TagBindingDeletePayload(BaseModel): - tag_id: str + tag_ids: list[str] = Field(min_length=1) target_id: str type: TagType @@ -178,13 +180,18 @@ class TagService: @staticmethod def delete_tag_binding(payload: TagBindingDeletePayload): TagService.check_target_exists(payload.type, payload.target_id) - tag_binding = db.session.scalar( - select(TagBinding) - .where(TagBinding.target_id == payload.target_id, TagBinding.tag_id == payload.tag_id) - .limit(1) + result = cast( + CursorResult, + db.session.execute( + delete(TagBinding).where( + TagBinding.target_id == payload.target_id, + TagBinding.tag_id.in_(payload.tag_ids), + TagBinding.tenant_id == current_user.current_tenant_id, + ) + ), ) - if tag_binding: - db.session.delete(tag_binding) + + if result.rowcount: db.session.commit() @staticmethod diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index 66a25e5daf..b4482674da 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -433,7 +433,7 @@ def flask_app_with_containers(set_up_containers_and_env) -> Flask: @pytest.fixture -def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, None, None]: +def flask_req_ctx_with_containers(flask_app_with_containers: Flask) -> Generator[None, None, None]: """ Request context fixture for containerized Flask application. @@ -454,7 +454,7 @@ def flask_req_ctx_with_containers(flask_app_with_containers) -> Generator[None, @pytest.fixture -def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskClient, None, None]: +def test_client_with_containers(flask_app_with_containers: Flask) -> Generator[FlaskClient, None, None]: """ Test client fixture for containerized Flask application. @@ -475,7 +475,7 @@ def test_client_with_containers(flask_app_with_containers) -> Generator[FlaskCli @pytest.fixture -def db_session_with_containers(flask_app_with_containers) -> Generator[Session, None, None]: +def db_session_with_containers(flask_app_with_containers: Flask) -> Generator[Session, None, None]: """ Database session fixture for containerized testing. diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 18755ef012..bb737754a1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -69,7 +70,7 @@ def _unwrap(func): class TestCompletionEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_completion_create_payload(self): @@ -86,7 +87,7 @@ class TestCompletionEndpoints: ) assert payload.query == "hi" - def test_completion_api_success(self, app, monkeypatch): + def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -116,7 +117,7 @@ class TestCompletionEndpoints: assert resp == {"result": {"text": "ok"}} - def test_completion_api_conversation_not_exists(self, app, monkeypatch): + def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -142,7 +143,7 @@ class TestCompletionEndpoints: with pytest.raises(NotFound): method(app_model=MagicMock(id="app-1")) - def test_completion_api_provider_not_initialized(self, app, monkeypatch): + def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -166,7 +167,7 @@ class TestCompletionEndpoints: with pytest.raises(completion_module.ProviderNotInitializeError): method(app_model=MagicMock(id="app-1")) - def test_completion_api_quota_exceeded(self, app, monkeypatch): + def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() method = _unwrap(api.post) @@ -193,10 +194,10 @@ class TestCompletionEndpoints: class TestAppEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppApi() method = _unwrap(api.put) payload = { @@ -234,7 +235,7 @@ class TestAppEndpoints: } ) - def test_app_icon_post_should_forward_icon_type(self, app, monkeypatch): + def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppIconApi() method = _unwrap(api.post) payload = { @@ -266,7 +267,7 @@ class TestAppEndpoints: class TestOpsTraceEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_ops_trace_query_basic(self): @@ -277,7 +278,7 @@ class TestOpsTraceEndpoints: payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"}) assert payload.tracing_config["api_key"] == "k" - def test_trace_app_config_get_empty(self, app, monkeypatch): + def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.get) @@ -292,7 +293,7 @@ class TestOpsTraceEndpoints: assert result == {"has_not_configured": True} - def test_trace_app_config_post_invalid(self, app, monkeypatch): + def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.post) @@ -309,7 +310,7 @@ class TestOpsTraceEndpoints: with pytest.raises(BadRequest): method(app_id="app-1") - def test_trace_app_config_delete_not_found(self, app, monkeypatch): + def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() method = _unwrap(api.delete) @@ -326,7 +327,7 @@ class TestOpsTraceEndpoints: class TestSiteEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_site_response_structure(self): @@ -337,7 +338,7 @@ class TestSiteEndpoints: payload = AppSiteUpdatePayload(default_language="en-US") assert payload.default_language == "en-US" - def test_app_site_update_post(self, app, monkeypatch): + def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSite() method = _unwrap(api.post) @@ -375,7 +376,7 @@ class TestSiteEndpoints: assert isinstance(result, dict) assert result["title"] == "My Site" - def test_app_site_access_token_reset(self, app, monkeypatch): + def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSiteAccessTokenReset() method = _unwrap(api.post) @@ -427,7 +428,7 @@ class TestWorkflowEndpoints: class TestWorkflowAppLogEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_app_log_query(self): @@ -438,7 +439,7 @@ class TestWorkflowAppLogEndpoints: query = WorkflowAppLogQuery(detail="true") assert query.detail is True - def test_workflow_app_log_api_get(self, app, monkeypatch): + def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_app_log_module.WorkflowAppLogApi() method = _unwrap(api.get) @@ -477,14 +478,14 @@ class TestWorkflowAppLogEndpoints: class TestWorkflowDraftVariableEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_variable_creation(self): payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test") assert payload.name == "var1" - def test_workflow_variable_collection_get(self, app, monkeypatch): + def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_draft_variable_module.WorkflowVariableCollectionApi() method = _unwrap(api.get) @@ -529,7 +530,7 @@ class TestWorkflowDraftVariableEndpoints: class TestWorkflowStatisticEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_workflow_statistic_time_range(self): @@ -541,7 +542,7 @@ class TestWorkflowStatisticEndpoints: assert query.start is None assert query.end is None - def test_workflow_daily_runs_statistic(self, app, monkeypatch): + def test_workflow_daily_runs_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -567,7 +568,7 @@ class TestWorkflowStatisticEndpoints: assert response.get_json() == {"data": [{"date": "2024-01-01"}]} - def test_workflow_daily_terminals_statistic(self, app, monkeypatch): + def test_workflow_daily_terminals_statistic(self, app: Flask, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) monkeypatch.setattr( workflow_statistic_module.DifyAPIRepositoryFactory, @@ -598,7 +599,7 @@ class TestWorkflowStatisticEndpoints: class TestWorkflowTriggerEndpoints: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_webhook_trigger_payload(self): @@ -608,7 +609,7 @@ class TestWorkflowTriggerEndpoints: enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True) assert enable_payload.enable_trigger is True - def test_webhook_trigger_api_get(self, app, monkeypatch): + def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_trigger_module.WebhookTriggerApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index 25d19cf35a..bcb6e41ef7 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from controllers.console.app import app_import as app_import_module from services.app_dsl_service import ImportStatus @@ -36,10 +37,10 @@ def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: class TestAppImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_post_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -57,7 +58,7 @@ class TestAppImportApi: assert status == 400 assert response["status"] == ImportStatus.FAILED - def test_import_post_returns_pending_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -75,7 +76,7 @@ class TestAppImportApi: assert status == 202 assert response["status"] == ImportStatus.PENDING - def test_import_post_updates_webapp_auth_when_enabled(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -96,7 +97,7 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED - def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -121,7 +122,7 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED - def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() method = _unwrap(api.post) @@ -149,10 +150,10 @@ class TestAppImportApi: class TestAppImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_confirm_returns_failed_status(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportConfirmApi() method = _unwrap(api.post) @@ -172,10 +173,10 @@ class TestAppImportConfirmApi: class TestAppImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_import_check_dependencies_returns_result(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportCheckDependenciesApi() method = _unwrap(api.get) diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py index 320da85b60..1fcce9ca44 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_email_register.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.email_register import ( EmailRegisterCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -33,7 +34,7 @@ class TestEmailRegisterSendEmailApi: mock_is_freeze, mock_send_mail, mock_get_account, - app, + app: Flask, ): mock_send_mail.return_value = "token-123" mock_is_freeze.return_value = False @@ -75,7 +76,7 @@ class TestEmailRegisterCheckApi: mock_revoke, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "4321"} @@ -120,7 +121,7 @@ class TestEmailRegisterResetApi: mock_create_account, mock_login, mock_reset_login_rate, - app, + app: Flask, ): mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"} mock_create_account.return_value = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py index d2703ed5cc..014c1588fe 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_forgot_password.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.forgot_password import ( ForgotPasswordCheckApi, @@ -16,7 +17,7 @@ from services.account_service import AccountService @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): return flask_app_with_containers @@ -31,7 +32,7 @@ class TestForgotPasswordSendEmailApi: mock_is_ip_limit, mock_send_email, mock_get_account, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -80,7 +81,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_rate_limit_check.return_value = False mock_get_data.return_value = {"email": "Admin@Example.com", "code": "4321"} @@ -123,7 +124,7 @@ class TestForgotPasswordResetApi: mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com"} mock_account = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py index 1eabb45422..01d88d247c 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_oauth.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.oauth import ( OAuthCallback, @@ -21,7 +22,7 @@ from services.errors.account import AccountRegisterError class TestGetOAuthProviders: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.mark.parametrize( @@ -65,7 +66,7 @@ class TestOAuthLogin: return OAuthLogin() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -130,7 +131,7 @@ class TestOAuthCallback: return OAuthCallback() @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -394,7 +395,7 @@ class TestOAuthCallback: class TestAccountGeneration: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py index 50249bcd74..8d6b25b5b3 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_password_reset.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console.auth.error import ( EmailCodeError, @@ -25,7 +26,7 @@ class TestForgotPasswordSendEmailApi: """Test cases for sending password reset emails.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -68,7 +69,7 @@ class TestForgotPasswordSendEmailApi: mock_send_email.assert_called_once() @patch("controllers.console.auth.forgot_password.AccountService.is_email_send_ip_limit") - def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app): + def test_send_reset_email_ip_rate_limited(self, mock_is_ip_limit, app: Flask): """ Test password reset email blocked by IP rate limit. @@ -138,7 +139,7 @@ class TestForgotPasswordCheckApi: """Test cases for verifying password reset codes.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @@ -221,7 +222,7 @@ class TestForgotPasswordCheckApi: mock_reset_rate_limit.assert_called_once_with("user@example.com") @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") - def test_verify_code_rate_limited(self, mock_is_rate_limit, app): + def test_verify_code_rate_limited(self, mock_is_rate_limit, app: Flask): """ Test code verification blocked by rate limit. @@ -244,7 +245,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_invalid_token(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with invalid token. @@ -267,7 +268,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_email_mismatch(self, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with mismatched email. @@ -292,7 +293,7 @@ class TestForgotPasswordCheckApi: @patch("controllers.console.auth.forgot_password.AccountService.is_forgot_password_error_rate_limit") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.add_forgot_password_error_rate_limit") - def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app): + def test_verify_code_wrong_code(self, mock_add_rate_limit, mock_get_data, mock_is_rate_limit, app: Flask): """ Test code verification with incorrect code. @@ -321,7 +322,7 @@ class TestForgotPasswordResetApi: """Test cases for resetting password with verified token.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -375,7 +376,7 @@ class TestForgotPasswordResetApi: mock_revoke_token.assert_called_once_with("valid_token") @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_mismatch(self, mock_get_data, app): + def test_reset_password_mismatch(self, mock_get_data, app: Flask): """ Test password reset with mismatched passwords. @@ -397,7 +398,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_invalid_token(self, mock_get_data, app): + def test_reset_password_invalid_token(self, mock_get_data, app: Flask): """ Test password reset with invalid token. @@ -418,7 +419,7 @@ class TestForgotPasswordResetApi: api.post() @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") - def test_reset_password_wrong_phase(self, mock_get_data, app): + def test_reset_password_wrong_phase(self, mock_get_data, app: Flask): """ Test password reset with token not in reset phase. @@ -442,7 +443,7 @@ class TestForgotPasswordResetApi: @patch("controllers.console.auth.forgot_password.AccountService.get_reset_password_data") @patch("controllers.console.auth.forgot_password.AccountService.revoke_reset_password_token") @patch("controllers.console.auth.forgot_password.AccountService.get_account_by_email_with_case_fallback") - def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app): + def test_reset_password_account_not_found(self, mock_get_account, mock_revoke_token, mock_get_data, app: Flask): """ Test password reset for non-existent account. diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index d5ae95dfb7..2752e6b34f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from controllers.console import console_ns @@ -26,7 +27,7 @@ def unwrap(func): class TestPipelineTemplateListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -50,7 +51,7 @@ class TestPipelineTemplateListApi: class TestPipelineTemplateDetailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -115,7 +116,7 @@ class TestPipelineTemplateDetailApi: class TestCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_patch_success(self, app): @@ -193,7 +194,7 @@ class TestCustomizedPipelineTemplateApi: class TestPublishCustomizedPipelineTemplateApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_post_success(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py index 64e3de2ca3..7624c1150f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden import services @@ -24,13 +25,13 @@ def unwrap(func): class TestCreateRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _valid_payload(self): return {"yaml_content": "name: test"} - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -58,7 +59,7 @@ class TestCreateRagPipelineDatasetApi: assert status == 201 assert response == import_info - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -76,7 +77,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(Forbidden): method(api) - def test_post_dataset_name_duplicate(self, app): + def test_post_dataset_name_duplicate(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -101,7 +102,7 @@ class TestCreateRagPipelineDatasetApi: with pytest.raises(DatasetNameDuplicateError): method(api) - def test_post_invalid_payload(self, app): + def test_post_invalid_payload(self, app: Flask): api = CreateRagPipelineDatasetApi() method = unwrap(api.post) @@ -122,10 +123,10 @@ class TestCreateRagPipelineDatasetApi: class TestCreateEmptyRagPipelineDatasetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_post_success(self, app): + def test_post_success(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) @@ -152,7 +153,7 @@ class TestCreateEmptyRagPipelineDatasetApi: assert status == 201 assert response == {"id": "ds-1"} - def test_post_forbidden_non_editor(self, app): + def test_post_forbidden_non_editor(self, app: Flask): api = CreateEmptyRagPipelineDatasetApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py index cb67892878..f238ca13ee 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.console import console_ns from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( @@ -25,7 +26,7 @@ def unwrap(func): class TestRagPipelineImportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _payload(self, mode="create"): @@ -128,7 +129,7 @@ class TestRagPipelineImportApi: class TestRagPipelineImportConfirmApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_confirm_success(self, app): @@ -190,7 +191,7 @@ class TestRagPipelineImportConfirmApi: class TestRagPipelineImportCheckDependenciesApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app): @@ -219,7 +220,7 @@ class TestRagPipelineImportCheckDependenciesApi: class TestRagPipelineExportApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_with_include_secret(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index c1f3122c2b..1fdb3057b8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound @@ -45,10 +46,10 @@ def unwrap(func): class TestDraftWorkflowApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_draft_success(self, app): + def test_get_draft_success(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -68,7 +69,7 @@ class TestDraftWorkflowApi: result = method(api, pipeline) assert result == workflow - def test_get_draft_not_exist(self, app): + def test_get_draft_not_exist(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.get) @@ -86,7 +87,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotExist): method(api, pipeline) - def test_sync_hash_not_match(self, app): + def test_sync_hash_not_match(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -111,7 +112,7 @@ class TestDraftWorkflowApi: with pytest.raises(DraftWorkflowNotSync): method(api, pipeline) - def test_sync_invalid_text_plain(self, app): + def test_sync_invalid_text_plain(self, app: Flask): api = DraftRagPipelineApi() method = unwrap(api.post) @@ -128,7 +129,7 @@ class TestDraftWorkflowApi: response, status = method(api, pipeline) assert status == 400 - def test_restore_published_workflow_to_draft_success(self, app): + def test_restore_published_workflow_to_draft_success(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -155,7 +156,7 @@ class TestDraftWorkflowApi: assert result["result"] == "success" assert result["hash"] == "restored-hash" - def test_restore_published_workflow_to_draft_not_found(self, app): + def test_restore_published_workflow_to_draft_not_found(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -179,7 +180,7 @@ class TestDraftWorkflowApi: with pytest.raises(NotFound): method(api, pipeline, "published-workflow") - def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app: Flask): api = RagPipelineDraftWorkflowRestoreApi() method = unwrap(api.post) @@ -211,10 +212,10 @@ class TestDraftWorkflowApi: class TestDraftRunNodes: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_iteration_node_success(self, app): + def test_iteration_node_success(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -240,7 +241,7 @@ class TestDraftRunNodes: result = method(api, pipeline, "node") assert result == {"ok": True} - def test_iteration_node_conversation_not_exists(self, app): + def test_iteration_node_conversation_not_exists(self, app: Flask): api = RagPipelineDraftRunIterationNodeApi() method = unwrap(api.post) @@ -262,7 +263,7 @@ class TestDraftRunNodes: with pytest.raises(NotFound): method(api, pipeline, "node") - def test_loop_node_success(self, app): + def test_loop_node_success(self, app: Flask): api = RagPipelineDraftRunLoopNodeApi() method = unwrap(api.post) @@ -290,10 +291,10 @@ class TestDraftRunNodes: class TestPipelineRunApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_draft_run_success(self, app): + def test_draft_run_success(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -325,7 +326,7 @@ class TestPipelineRunApis: ): assert method(api, pipeline) == {"ok": True} - def test_draft_run_rate_limit(self, app): + def test_draft_run_rate_limit(self, app: Flask): api = DraftRagPipelineRunApi() method = unwrap(api.post) @@ -356,10 +357,10 @@ class TestPipelineRunApis: class TestDraftNodeRun: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_execution_not_found(self, app): + def test_execution_not_found(self, app: Flask): api = RagPipelineDraftNodeRunApi() method = unwrap(api.post) @@ -387,7 +388,7 @@ class TestDraftNodeRun: class TestPublishedPipelineApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_publish_success(self, app, db_session_with_containers: Session): @@ -436,10 +437,10 @@ class TestPublishedPipelineApis: class TestMiscApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_task_stop(self, app): + def test_task_stop(self, app: Flask): api = RagPipelineTaskStopApi() method = unwrap(api.post) @@ -460,7 +461,7 @@ class TestMiscApis: stop_mock.assert_called_once() assert result["result"] == "success" - def test_transform_forbidden(self, app): + def test_transform_forbidden(self, app: Flask): api = RagPipelineTransformApi() method = unwrap(api.post) @@ -476,7 +477,7 @@ class TestMiscApis: with pytest.raises(Forbidden): method(api, "ds1") - def test_recommended_plugins(self, app): + def test_recommended_plugins(self, app: Flask): api = RagPipelineRecommendedPluginApi() method = unwrap(api.get) @@ -496,10 +497,10 @@ class TestMiscApis: class TestPublishedRagPipelineRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_published_run_success(self, app): + def test_published_run_success(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -533,7 +534,7 @@ class TestPublishedRagPipelineRunApi: result = method(api, pipeline) assert result == {"ok": True} - def test_published_run_rate_limit(self, app): + def test_published_run_rate_limit(self, app: Flask): api = PublishedRagPipelineRunApi() method = unwrap(api.post) @@ -565,10 +566,10 @@ class TestPublishedRagPipelineRunApi: class TestDefaultBlockConfigApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_block_config_success(self, app): + def test_get_block_config_success(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -587,7 +588,7 @@ class TestDefaultBlockConfigApi: result = method(api, pipeline, "llm") assert result == {"k": "v"} - def test_get_block_config_invalid_json(self, app): + def test_get_block_config_invalid_json(self, app: Flask): api = DefaultRagPipelineBlockConfigApi() method = unwrap(api.get) @@ -600,10 +601,10 @@ class TestDefaultBlockConfigApi: class TestPublishedAllRagPipelineApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_published_workflows_success(self, app): + def test_get_published_workflows_success(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -629,7 +630,7 @@ class TestPublishedAllRagPipelineApi: assert result["items"] == [{"id": "w1"}] assert result["has_more"] is False - def test_get_published_workflows_forbidden(self, app): + def test_get_published_workflows_forbidden(self, app: Flask): api = PublishedAllRagPipelineApi() method = unwrap(api.get) @@ -649,10 +650,10 @@ class TestPublishedAllRagPipelineApi: class TestRagPipelineByIdApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_patch_success(self, app): + def test_patch_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -682,7 +683,7 @@ class TestRagPipelineByIdApi: assert result == workflow - def test_patch_no_fields(self, app): + def test_patch_no_fields(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.patch) @@ -700,7 +701,7 @@ class TestRagPipelineByIdApi: result, status = method(api, pipeline, "w1") assert status == 400 - def test_delete_success(self, app): + def test_delete_success(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -720,7 +721,7 @@ class TestRagPipelineByIdApi: workflow_service.delete_workflow.assert_called_once() assert result == (None, 204) - def test_delete_active_workflow_rejected(self, app): + def test_delete_active_workflow_rejected(self, app: Flask): api = RagPipelineByIdApi() method = unwrap(api.delete) @@ -733,10 +734,10 @@ class TestRagPipelineByIdApi: class TestRagPipelineWorkflowLastRunApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_last_run_success(self, app): + def test_last_run_success(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -758,7 +759,7 @@ class TestRagPipelineWorkflowLastRunApi: result = method(api, pipeline, "node1") assert result == node_exec - def test_last_run_not_found(self, app): + def test_last_run_not_found(self, app: Flask): api = RagPipelineWorkflowLastRunApi() method = unwrap(api.get) @@ -780,10 +781,10 @@ class TestRagPipelineWorkflowLastRunApi: class TestRagPipelineDatasourceVariableApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_set_datasource_variables_success(self, app): + def test_set_datasource_variables_success(self, app: Flask): api = RagPipelineDatasourceVariableApi() method = unwrap(api.post) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 1c4c6a899f..50ad92afa1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, PropertyMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.console.datasets import data_source @@ -51,7 +52,7 @@ def mock_engine(): class TestDataSourceApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): @@ -188,7 +189,7 @@ class TestDataSourceApi: class TestDataSourceNotionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_credential_not_found(self, app, patch_tenant): @@ -323,7 +324,7 @@ class TestDataSourceNotionListApi: class TestDataSourceNotionApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_preview_success(self, app, patch_tenant): @@ -381,7 +382,7 @@ class TestDataSourceNotionApi: class TestDataSourceNotionDatasetSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): @@ -424,7 +425,7 @@ class TestDataSourceNotionDatasetSyncApi: class TestDataSourceNotionDocumentSyncApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, patch_tenant): diff --git a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py index 83492048ef..0b53ca5585 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/console/explore/test_conversation.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import NotFound import controllers.console.explore.conversation as conversation_module @@ -53,7 +54,7 @@ def user(): class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_success(self, app, chat_app, user): @@ -108,7 +109,7 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_delete_success(self, app, chat_app, user): @@ -156,7 +157,7 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_rename_success(self, app, chat_app, user): @@ -197,7 +198,7 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_pin_success(self, app, chat_app, user): @@ -219,7 +220,7 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_unpin_success(self, app, chat_app, user): diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py index f2e7104b18..d944613886 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_tool_provider.py @@ -6,6 +6,7 @@ import json from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import Forbidden from controllers.console.workspace.tool_providers import ( @@ -60,7 +61,7 @@ def _mock_user_tenant(): @pytest.fixture -def client(flask_app_with_containers): +def client(flask_app_with_containers: Flask): return flask_app_with_containers.test_client() @@ -147,10 +148,10 @@ class TestUtils: class TestToolProviderListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_get_success(self, app): + def test_get_success(self, app: Flask): api = ToolProviderListApi() method = unwrap(api.get) @@ -170,10 +171,10 @@ class TestToolProviderListApi: class TestBuiltinProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolBuiltinProviderListToolsApi() method = unwrap(api.get) @@ -190,7 +191,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == [{"a": 1}] - def test_info(self, app): + def test_info(self, app: Flask): api = ToolBuiltinProviderInfoApi() method = unwrap(api.get) @@ -207,7 +208,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"x": 1} - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolBuiltinProviderDeleteApi() method = unwrap(api.post) @@ -224,7 +225,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["result"] == "success" - def test_add_invalid_type(self, app): + def test_add_invalid_type(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -238,7 +239,7 @@ class TestBuiltinProviderApis: with pytest.raises(ValueError): method(api, "provider") - def test_add_success(self, app): + def test_add_success(self, app: Flask): api = ToolBuiltinProviderAddApi() method = unwrap(api.post) @@ -257,7 +258,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["id"] == 1 - def test_update(self, app): + def test_update(self, app: Flask): api = ToolBuiltinProviderUpdateApi() method = unwrap(api.post) @@ -276,7 +277,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credentials(self, app): + def test_get_credentials(self, app: Flask): api = ToolBuiltinProviderGetCredentialsApi() method = unwrap(api.get) @@ -293,7 +294,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"k": "v"} - def test_icon(self, app): + def test_icon(self, app: Flask): api = ToolBuiltinProviderIconApi() method = unwrap(api.get) @@ -307,7 +308,7 @@ class TestBuiltinProviderApis: response = method(api, "provider") assert response.mimetype == "image/png" - def test_credentials_schema(self, app): + def test_credentials_schema(self, app: Flask): api = ToolBuiltinProviderCredentialsSchemaApi() method = unwrap(api.get) @@ -324,7 +325,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider", "oauth2") == {"schema": {}} - def test_set_default_credential(self, app): + def test_set_default_credential(self, app: Flask): api = ToolBuiltinProviderSetDefaultApi() method = unwrap(api.post) @@ -341,7 +342,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider")["ok"] - def test_get_credential_info(self, app): + def test_get_credential_info(self, app: Flask): api = ToolBuiltinProviderGetCredentialInfoApi() method = unwrap(api.get) @@ -358,7 +359,7 @@ class TestBuiltinProviderApis: ): assert method(api, "provider") == {"info": "x"} - def test_get_oauth_client_schema(self, app): + def test_get_oauth_client_schema(self, app: Flask): api = ToolBuiltinProviderGetOauthClientSchemaApi() method = unwrap(api.get) @@ -378,10 +379,10 @@ class TestBuiltinProviderApis: class TestApiProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_add(self, app): + def test_add(self, app: Flask): api = ToolApiProviderAddApi() method = unwrap(api.post) @@ -406,7 +407,7 @@ class TestApiProviderApis: ): assert method(api)["id"] == 1 - def test_remote_schema(self, app): + def test_remote_schema(self, app: Flask): api = ToolApiProviderGetRemoteSchemaApi() method = unwrap(api.get) @@ -423,7 +424,7 @@ class TestApiProviderApis: ): assert method(api)["schema"] == "x" - def test_list_tools(self, app): + def test_list_tools(self, app: Flask): api = ToolApiProviderListToolsApi() method = unwrap(api.get) @@ -440,7 +441,7 @@ class TestApiProviderApis: ): assert method(api) == [{"tool": 1}] - def test_update(self, app): + def test_update(self, app: Flask): api = ToolApiProviderUpdateApi() method = unwrap(api.post) @@ -468,7 +469,7 @@ class TestApiProviderApis: ): assert method(api)["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolApiProviderDeleteApi() method = unwrap(api.post) @@ -485,7 +486,7 @@ class TestApiProviderApis: ): assert method(api)["result"] == "success" - def test_get(self, app): + def test_get(self, app: Flask): api = ToolApiProviderGetApi() method = unwrap(api.get) @@ -505,10 +506,10 @@ class TestApiProviderApis: class TestWorkflowApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_create(self, app): + def test_create(self, app: Flask): api = ToolWorkflowProviderCreateApi() method = unwrap(api.post) @@ -534,7 +535,7 @@ class TestWorkflowApis: ): assert method(api)["id"] == 1 - def test_update_invalid(self, app): + def test_update_invalid(self, app: Flask): api = ToolWorkflowProviderUpdateApi() method = unwrap(api.post) @@ -560,7 +561,7 @@ class TestWorkflowApis: result = method(api) assert result["ok"] - def test_delete(self, app): + def test_delete(self, app: Flask): api = ToolWorkflowProviderDeleteApi() method = unwrap(api.post) @@ -577,7 +578,7 @@ class TestWorkflowApis: ): assert method(api)["ok"] - def test_get_error(self, app): + def test_get_error(self, app: Flask): api = ToolWorkflowProviderGetApi() method = unwrap(api.get) @@ -594,10 +595,10 @@ class TestWorkflowApis: class TestLists: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_builtin_list(self, app): + def test_builtin_list(self, app: Flask): api = ToolBuiltinListApi() method = unwrap(api.get) @@ -617,7 +618,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_api_list(self, app): + def test_api_list(self, app: Flask): api = ToolApiListApi() method = unwrap(api.get) @@ -637,7 +638,7 @@ class TestLists: ): assert method(api) == [{"x": 1}] - def test_workflow_list(self, app): + def test_workflow_list(self, app: Flask): api = ToolWorkflowListApi() method = unwrap(api.get) @@ -660,10 +661,10 @@ class TestLists: class TestLabels: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_labels(self, app): + def test_labels(self, app: Flask): api = ToolLabelsApi() method = unwrap(api.get) @@ -679,10 +680,10 @@ class TestLabels: class TestOAuth: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_oauth_no_client(self, app): + def test_oauth_no_client(self, app: Flask): api = ToolPluginOAuthApi() method = unwrap(api.get) @@ -700,7 +701,7 @@ class TestOAuth: with pytest.raises(Forbidden): method(api, "provider") - def test_oauth_callback_no_cookie(self, app): + def test_oauth_callback_no_cookie(self, app: Flask): api = ToolOAuthCallback() method = unwrap(api.get) @@ -711,10 +712,10 @@ class TestOAuth: class TestOAuthCustomClient: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_save_custom_client(self, app): + def test_save_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.post) @@ -731,7 +732,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider")["ok"] - def test_get_custom_client(self, app): + def test_get_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.get) @@ -748,7 +749,7 @@ class TestOAuthCustomClient: ): assert method(api, "provider") == {"client_id": "x"} - def test_delete_custom_client(self, app): + def test_delete_custom_client(self, app: Flask): api = ToolOAuthCustomClient() method = unwrap(api.delete) diff --git a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py index ca8195af53..6efdaf2943 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py +++ b/api/tests/test_containers_integration_tests/controllers/console/workspace/test_trigger_providers.py @@ -5,6 +5,7 @@ from __future__ import annotations from unittest.mock import MagicMock, patch import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden from controllers.console.workspace.trigger_providers import ( @@ -45,7 +46,7 @@ def mock_user(): class TestTriggerProviderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_icon_success(self, app): @@ -93,7 +94,7 @@ class TestTriggerProviderApis: class TestTriggerSubscriptionListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_list_success(self, app): @@ -128,7 +129,7 @@ class TestTriggerSubscriptionListApi: class TestTriggerSubscriptionBuilderApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_create_builder(self, app): @@ -236,7 +237,7 @@ class TestTriggerSubscriptionBuilderApis: class TestTriggerSubscriptionCrud: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_update_rename_only(self, app): @@ -342,7 +343,7 @@ class TestTriggerSubscriptionCrud: class TestTriggerOAuthApis: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_oauth_authorize_success(self, app): @@ -480,7 +481,7 @@ class TestTriggerOAuthApis: class TestTriggerOAuthClientManageApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_get_client(self, app): @@ -556,7 +557,7 @@ class TestTriggerOAuthClientManageApi: class TestTriggerSubscriptionVerifyApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def test_verify_success(self, app): diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 9b913d6d3d..5791d2f6e2 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -18,6 +18,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden, NotFound @@ -217,10 +218,20 @@ class TestTagUnbindingPayload: """Test suite for TagUnbindingPayload Pydantic model.""" def test_payload_with_valid_data(self): - payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456") - assert payload.tag_id == "tag_123" + payload = TagUnbindingPayload(tag_ids=["tag_123"], target_id="dataset_456") + assert payload.tag_ids == ["tag_123"] assert payload.target_id == "dataset_456" + def test_payload_normalizes_legacy_tag_id(self): + payload = TagUnbindingPayload(tag_id="tag_123", target_id="dataset_456") + assert payload.tag_ids == ["tag_123"] + assert payload.target_id == "dataset_456" + + def test_payload_rejects_empty_tag_ids(self): + with pytest.raises(ValueError) as exc_info: + TagUnbindingPayload(tag_ids=[], target_id="dataset_456") + assert "Tag IDs is required" in str(exc_info.value) + # --------------------------------------------------------------------------- # Helpers @@ -236,7 +247,7 @@ def _unwrap(method): @pytest.fixture -def app(flask_app_with_containers): +def app(flask_app_with_containers: Flask): # Uses the full containerised app so that Flask config, extensions, and # blueprint registrations match production. Most tests mock the service # layer to isolate controller logic; a few (e.g. test_list_tags_from_db) @@ -1012,6 +1023,36 @@ class TestDatasetTagUnbindingApiPost: mock_current_user.is_dataset_editor = True mock_tag_svc.delete_tag_binding.return_value = None + with app.test_request_context( + "/datasets/tags/unbinding", + method="POST", + json={"tag_ids": ["tag-1"], "target_id": "ds-1"}, + ): + api = DatasetTagUnbindingApi() + result = api.post(_=None) + + assert result == ("", 204) + from services.tag_service import TagBindingDeletePayload + + mock_tag_svc.delete_tag_binding.assert_called_once_with( + TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") + ) + + @patch("controllers.service_api.dataset.dataset.TagService") + @patch("controllers.service_api.dataset.dataset.current_user") + def test_unbind_legacy_tag_id_success( + self, + mock_current_user, + mock_tag_svc, + app, + ): + from controllers.service_api.dataset.dataset import DatasetTagUnbindingApi + + mock_current_user.__class__ = Account + mock_current_user.has_edit_permission = True + mock_current_user.is_dataset_editor = True + mock_tag_svc.delete_tag_binding.return_value = None + with app.test_request_context( "/datasets/tags/unbinding", method="POST", @@ -1024,7 +1065,7 @@ class TestDatasetTagUnbindingApiPost: from services.tag_service import TagBindingDeletePayload mock_tag_svc.delete_tag_binding.assert_called_once_with( - TagBindingDeletePayload(tag_id="tag-1", target_id="ds-1", type="knowledge") + TagBindingDeletePayload(tag_ids=["tag-1"], target_id="ds-1", type="knowledge") ) @patch("controllers.service_api.dataset.dataset.current_user") @@ -1038,7 +1079,7 @@ class TestDatasetTagUnbindingApiPost: with app.test_request_context( "/datasets/tags/unbinding", method="POST", - json={"tag_id": "tag-1", "target_id": "ds-1"}, + json={"tag_ids": ["tag-1"], "target_id": "ds-1"}, ): api = DatasetTagUnbindingApi() with pytest.raises(Forbidden): diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py index e1e6741014..c34da27ebe 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_conversation.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.web.conversation import ( @@ -34,16 +35,16 @@ def _end_user() -> SimpleNamespace: class TestConversationListApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context("/conversations"): with pytest.raises(NotChatAppError): ConversationListApi().get(_completion_app(), _end_user()) @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id") - def test_happy_path(self, mock_paginate: MagicMock, app) -> None: + def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None: conv_id = str(uuid4()) conv = SimpleNamespace( id=conv_id, @@ -65,16 +66,16 @@ class TestConversationListApi: class TestConversationApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}"): with pytest.raises(NotChatAppError): ConversationApi().delete(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.ConversationService.delete") - def test_delete_success(self, mock_delete: MagicMock, app) -> None: + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}"): result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) @@ -83,7 +84,7 @@ class TestConversationApi: assert result["result"] == "success" @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) - def test_delete_not_found(self, mock_delete: MagicMock, app) -> None: + def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}"): with pytest.raises(NotFound, match="Conversation Not Exists"): @@ -92,17 +93,17 @@ class TestConversationApi: class TestConversationRenameApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}): with pytest.raises(NotChatAppError): ConversationRenameApi().post(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.ConversationService.rename") @patch("controllers.web.conversation.web_ns") - def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None: + def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: c_id = uuid4() mock_ns.payload = {"name": "New Name", "auto_generate": False} conv = SimpleNamespace( @@ -126,7 +127,7 @@ class TestConversationRenameApi: side_effect=ConversationNotExistsError(), ) @patch("controllers.web.conversation.web_ns") - def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app) -> None: + def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: c_id = uuid4() mock_ns.payload = {"name": "X", "auto_generate": False} @@ -137,16 +138,16 @@ class TestConversationRenameApi: class TestConversationPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"): with pytest.raises(NotChatAppError): ConversationPinApi().patch(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.WebConversationService.pin") - def test_pin_success(self, mock_pin: MagicMock, app) -> None: + def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id) @@ -154,7 +155,7 @@ class TestConversationPinApi: assert result["result"] == "success" @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError()) - def test_pin_not_found(self, mock_pin: MagicMock, app) -> None: + def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): with pytest.raises(NotFound): @@ -163,16 +164,16 @@ class TestConversationPinApi: class TestConversationUnPinApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers - def test_non_chat_mode_raises(self, app) -> None: + def test_non_chat_mode_raises(self, app: Flask) -> None: with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"): with pytest.raises(NotChatAppError): ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4()) @patch("controllers.web.conversation.WebConversationService.unpin") - def test_unpin_success(self, mock_unpin: MagicMock, app) -> None: + def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None: c_id = uuid4() with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"): result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id) diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py index 635cfee2da..2c6a990240 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_web_forgot_password.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from flask import Flask from controllers.web.forgot_password import ( ForgotPasswordCheckApi, @@ -29,7 +30,7 @@ def _patch_wraps(): class TestForgotPasswordSendEmailApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.AccountService.send_reset_password_email") @@ -42,7 +43,7 @@ class TestForgotPasswordSendEmailApi: mock_rate_limit, mock_get_account, mock_send_mail, - app, + app: Flask, ): mock_account = MagicMock() mock_get_account.return_value = mock_account @@ -64,7 +65,7 @@ class TestForgotPasswordSendEmailApi: class TestForgotPasswordCheckApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.AccountService.reset_forgot_password_error_rate_limit") @@ -81,7 +82,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "User@Example.com", "code": "1234"} @@ -117,7 +118,7 @@ class TestForgotPasswordCheckApi: mock_revoke_token, mock_generate_token, mock_reset_rate, - app, + app: Flask, ): mock_is_rate_limit.return_value = False mock_get_data.return_value = {"email": "MixedCase@Example.com", "code": "5678"} @@ -142,7 +143,7 @@ class TestForgotPasswordCheckApi: class TestForgotPasswordResetApi: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @patch("controllers.web.forgot_password.ForgotPasswordResetApi._update_existing_account") @@ -157,7 +158,7 @@ class TestForgotPasswordResetApi: mock_db, mock_get_account, mock_update_account, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "User@Example.com", "code": "1234"} mock_account = MagicMock() @@ -194,7 +195,7 @@ class TestForgotPasswordResetApi: mock_db, mock_token_bytes, mock_hash_password, - app, + app: Flask, ): mock_get_reset_data.return_value = {"phase": "reset", "email": "user@example.com"} account = MagicMock() diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py index 19833cc772..de9e691434 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_wraps.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, NotFound, Unauthorized @@ -182,7 +183,7 @@ class TestValidateUserAccessibility: class TestDecodeJwtToken: @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers def _create_app_site_enduser(self, db_session: Session, *, enable_site: bool = True): diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index c342e8994b..bd13527e14 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -85,7 +85,7 @@ class TestPauseStatePersistenceLayerTestContainers: return WorkflowRunService(engine) @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers, file_service, workflow_run_service): + def setup_test_data(self, db_session_with_containers: Session, file_service, workflow_run_service): """Set up test data for each test method using TestContainers.""" # Create test tenant and account from models.account import AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus @@ -295,7 +295,7 @@ class TestPauseStatePersistenceLayerTestContainers: generate_entity=entity, ) - def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers): + def test_complete_pause_flow_with_real_dependencies(self, db_session_with_containers: Session): """Test complete pause flow: event -> state serialization -> database save -> storage save.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -352,7 +352,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert isinstance(persisted_entity, WorkflowAppGenerateEntity) assert persisted_entity.workflow_execution_id == self.test_workflow_run_id - def test_state_persistence_and_retrieval(self, db_session_with_containers): + def test_state_persistence_and_retrieval(self, db_session_with_containers: Session): """Test that pause state can be persisted and retrieved correctly.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -402,7 +402,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert retrieved_state["node_run_steps"] == 10 assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id - def test_database_transaction_handling(self, db_session_with_containers): + def test_database_transaction_handling(self, db_session_with_containers: Session): """Test that database transactions are handled correctly.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -433,7 +433,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert pause_model.resumed_at is None assert pause_model.state_object_key != "" - def test_file_storage_integration(self, db_session_with_containers): + def test_file_storage_integration(self, db_session_with_containers: Session): """Test integration with file storage system.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -467,7 +467,7 @@ class TestPauseStatePersistenceLayerTestContainers: assert resumption_context.serialized_graph_runtime_state == graph_runtime_state.dumps() assert resumption_context.get_generate_entity().workflow_execution_id == self.test_workflow_run_id - def test_workflow_with_different_creators(self, db_session_with_containers): + def test_workflow_with_different_creators(self, db_session_with_containers: Session): """Test pause state with workflows created by different users.""" # Arrange - Create workflow with different creator different_user_id = str(uuid.uuid4()) @@ -532,7 +532,7 @@ class TestPauseStatePersistenceLayerTestContainers: resumption_context = WorkflowResumptionContext.loads(pause_entity.get_state().decode()) assert resumption_context.get_generate_entity().workflow_execution_id == different_workflow_run.id - def test_layer_ignores_non_pause_events(self, db_session_with_containers): + def test_layer_ignores_non_pause_events(self, db_session_with_containers: Session): """Test that layer ignores non-pause events.""" # Arrange layer = self._create_pause_state_persistence_layer() @@ -562,7 +562,7 @@ class TestPauseStatePersistenceLayerTestContainers: ).all() assert len(pause_states) == 0 - def test_layer_requires_initialization(self, db_session_with_containers): + def test_layer_requires_initialization(self, db_session_with_containers: Session): """Test that layer requires proper initialization before handling events.""" # Arrange layer = self._create_pause_state_persistence_layer() diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index a60159c66a..54ee133bfe 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -15,6 +15,7 @@ from uuid import uuid4 import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client @@ -40,7 +41,7 @@ class TestTenantIsolatedTaskQueueIntegration: return Faker() @pytest.fixture - def test_tenant_and_account(self, db_session_with_containers, fake): + def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker): """Create test tenant and account for testing.""" # Create account account = Account( @@ -94,7 +95,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert queue._queue == f"tenant_self_test-key_task_queue:{tenant.id}" assert queue._task_key == f"tenant_test-key_task:{tenant.id}" - def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers, fake): + def test_tenant_isolation(self, test_tenant_and_account, db_session_with_containers: Session, fake: Faker): """Test that different tenants have isolated queues.""" tenant1, _ = test_tenant_and_account @@ -176,7 +177,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert len(remaining_tasks) == 2 assert remaining_tasks == ["task4", "task5"] - def test_push_and_pull_complex_objects(self, test_queue, fake): + def test_push_and_pull_complex_objects(self, test_queue, fake: Faker): """Test pushing and pulling complex object tasks.""" # Create complex task objects as dictionaries (not dataclass instances) tasks = [ @@ -218,7 +219,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert pulled_task["data"] == original_task["data"] assert pulled_task["metadata"] == original_task["metadata"] - def test_mixed_task_types(self, test_queue, fake): + def test_mixed_task_types(self, test_queue, fake: Faker): """Test pushing and pulling mixed string and object tasks.""" string_task = "simple_string_task" object_task = { @@ -267,7 +268,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Verify task key has expired assert test_queue.get_task_key() is None - def test_large_task_batch(self, test_queue, fake): + def test_large_task_batch(self, test_queue, fake: Faker): """Test handling large batches of tasks.""" # Create large batch of tasks large_batch = [] @@ -292,7 +293,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert isinstance(task, dict) assert task["index"] == i # FIFO order - def test_queue_operations_isolation(self, test_tenant_and_account, fake): + def test_queue_operations_isolation(self, test_tenant_and_account, fake: Faker): """Test concurrent operations on different queues.""" tenant, _ = test_tenant_and_account @@ -312,7 +313,7 @@ class TestTenantIsolatedTaskQueueIntegration: assert tasks2 == ["task1_queue2", "task2_queue2"] assert tasks1 != tasks2 - def test_task_wrapper_serialization_roundtrip(self, test_queue, fake): + def test_task_wrapper_serialization_roundtrip(self, test_queue, fake: Faker): """Test TaskWrapper serialization and deserialization roundtrip.""" # Create complex nested data complex_data = { @@ -346,7 +347,7 @@ class TestTenantIsolatedTaskQueueIntegration: task = test_queue.pull_tasks(1) assert task[0] == invalid_json_task - def test_real_world_batch_processing_scenario(self, test_queue, fake): + def test_real_world_batch_processing_scenario(self, test_queue, fake: Faker): """Test realistic batch processing scenario.""" # Simulate batch processing tasks batch_tasks = [] @@ -403,7 +404,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return Faker() @pytest.fixture - def test_tenant_and_account(self, db_session_with_containers, fake): + def test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker): """Create test tenant and account for testing.""" # Create account account = Account( @@ -435,7 +436,7 @@ class TestTenantIsolatedTaskQueueCompatibility: return tenant, account - def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake): + def test_legacy_string_queue_compatibility(self, test_tenant_and_account, fake: Faker): """ Test compatibility with legacy queues containing only string data. @@ -465,7 +466,7 @@ class TestTenantIsolatedTaskQueueCompatibility: expected_order = ["legacy_task_1", "legacy_task_2", "legacy_task_3", "legacy_task_4", "legacy_task_5"] assert pulled_tasks == expected_order - def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake): + def test_legacy_queue_migration_scenario(self, test_tenant_and_account, fake: Faker): """ Test complete migration scenario from legacy to new system. @@ -546,7 +547,7 @@ class TestTenantIsolatedTaskQueueCompatibility: assert task["tenant_id"] == tenant.id assert task["processing_type"] == "new_system" - def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake): + def test_legacy_queue_error_recovery(self, test_tenant_and_account, fake: Faker): """ Test error recovery when legacy queue contains malformed data. diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 00d7496a40..9da6b04a2c 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -15,7 +16,7 @@ from tests.test_containers_integration_tests.helpers import generate_valid_passw class TestGetAvailableDatasetsIntegration: def test_returns_datasets_with_available_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -77,7 +78,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].name == dataset.name def test_filters_out_datasets_with_only_archived_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -130,7 +131,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_filters_out_datasets_with_only_disabled_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -183,7 +184,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_filters_out_datasets_with_non_completed_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -236,7 +237,7 @@ class TestGetAvailableDatasetsIntegration: assert len(result) == 0 def test_includes_external_datasets_without_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that external datasets are returned even with no available documents. @@ -280,7 +281,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].id == dataset.id assert result[0].provider == "external" - def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_filters_by_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): # Arrange fake = Faker() @@ -356,7 +357,7 @@ class TestGetAvailableDatasetsIntegration: assert result[0].tenant_id == tenant1.id def test_returns_empty_list_when_no_datasets_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -379,7 +380,9 @@ class TestGetAvailableDatasetsIntegration: # Assert assert result == [] - def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies): + def test_returns_only_requested_dataset_ids( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): # Arrange fake = Faker() @@ -439,7 +442,7 @@ class TestGetAvailableDatasetsIntegration: class TestKnowledgeRetrievalIntegration: def test_knowledge_retrieval_with_available_datasets( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -507,7 +510,7 @@ class TestKnowledgeRetrievalIntegration: assert isinstance(result, list) def test_knowledge_retrieval_no_available_datasets( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() @@ -555,7 +558,7 @@ class TestKnowledgeRetrievalIntegration: assert result == [] def test_knowledge_retrieval_rate_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): # Arrange fake = Faker() diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index 177fb95ff3..e71079829f 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -5,6 +5,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_service import ApiKeyAuthService @@ -31,7 +32,7 @@ class TestApiKeyAuthService: def mock_args(self, category, provider, mock_credentials) -> dict: return {"category": category, "provider": provider, "credentials": mock_credentials} - def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + def _create_binding(self, db_session: Session, *, tenant_id, category, provider, credentials=None, disabled=False): binding = DataSourceApiKeyAuthBinding( tenant_id=tenant_id, category=category, @@ -44,7 +45,7 @@ class TestApiKeyAuthService: return binding def test_get_provider_auth_list_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() @@ -56,14 +57,16 @@ class TestApiKeyAuthService: assert len(tenant_results) == 1 assert tenant_results[0].provider == provider - def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_get_provider_auth_list_empty( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): result = ApiKeyAuthService.get_provider_auth_list(tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] def test_get_provider_auth_list_filters_disabled( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True @@ -78,7 +81,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_success( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -97,7 +106,7 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_validation_failed( - self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, mock_factory, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_args ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = False @@ -112,7 +121,13 @@ class TestApiKeyAuthService: @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") def test_create_provider_auth_encrypts_api_key( - self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + self, + mock_encrypter, + mock_factory, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + mock_args, ): mock_auth_instance = Mock() mock_auth_instance.validate_credentials.return_value = True @@ -128,7 +143,13 @@ class TestApiKeyAuthService: mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) def test_get_auth_credentials_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + self, + flask_app_with_containers, + db_session_with_containers: Session, + tenant_id, + category, + provider, + mock_credentials, ): self._create_binding( db_session_with_containers, @@ -144,14 +165,14 @@ class TestApiKeyAuthService: assert result == mock_credentials def test_get_auth_credentials_not_found( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) assert result is None def test_get_auth_credentials_json_parsing( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} self._create_binding( @@ -169,7 +190,7 @@ class TestApiKeyAuthService: assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" def test_delete_provider_auth_success( - self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, category, provider ): binding = self._create_binding( db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider @@ -183,7 +204,9 @@ class TestApiKeyAuthService: remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() assert remaining is None - def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + def test_delete_provider_auth_not_found( + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id + ): # Should not raise when binding not found ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index f48c6da690..e78fa27976 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -10,6 +10,7 @@ from uuid import uuid4 import httpx import pytest +from sqlalchemy.orm import Session from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory @@ -114,7 +115,7 @@ class TestAuthIntegration: assert result2[0].tenant_id == tenant_id_2 def test_cross_tenant_access_prevention( - self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id_2, category ): result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index 42d587b7f7..327f14ddfe 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -12,6 +12,7 @@ from unittest.mock import create_autospec, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType @@ -273,7 +274,9 @@ class TestDocumentServicePauseDocument: "user_id": user_id, } - def test_pause_document_waiting_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_waiting_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in waiting state. @@ -310,7 +313,7 @@ class TestDocumentServicePauseDocument: mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(expected_cache_key, "True") def test_pause_document_indexing_state_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful pause of document in indexing state. @@ -340,7 +343,9 @@ class TestDocumentServicePauseDocument: assert document.is_paused is True assert document.paused_by == mock_document_service_dependencies["user_id"] - def test_pause_document_parsing_state_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_parsing_state_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful pause of document in parsing state. @@ -367,7 +372,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is True - def test_pause_document_completed_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_completed_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause completed document. @@ -396,7 +403,9 @@ class TestDocumentServicePauseDocument: db_session_with_containers.refresh(document) assert document.is_paused is False - def test_pause_document_error_state_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_pause_document_error_state_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to pause document in error state. @@ -467,7 +476,9 @@ class TestDocumentServiceRecoverDocument: "recover_task": mock_task, } - def test_recover_document_paused_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_paused_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful recovery of paused document. @@ -510,7 +521,9 @@ class TestDocumentServiceRecoverDocument: document.dataset_id, document.id ) - def test_recover_document_not_paused_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_recover_document_not_paused_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when trying to recover non-paused document. @@ -590,7 +603,9 @@ class TestDocumentServiceRetryDocument: "user_id": user_id, } - def test_retry_document_single_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_single_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of single document. @@ -629,7 +644,9 @@ class TestDocumentServiceRetryDocument: dataset.id, [document.id], mock_document_service_dependencies["user_id"] ) - def test_retry_document_multiple_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_retry_document_multiple_success( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test successful retry of multiple documents. @@ -675,7 +692,7 @@ class TestDocumentServiceRetryDocument: ) def test_retry_document_concurrent_retry_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is already being retried. @@ -708,7 +725,7 @@ class TestDocumentServiceRetryDocument: assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when current_user is missing. @@ -794,7 +811,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: } def test_batch_update_document_status_enable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch enabling of documents. @@ -844,7 +861,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: assert mock_document_service_dependencies["add_task"].delay.call_count == 2 def test_batch_update_document_status_disable_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch disabling of documents. @@ -886,7 +903,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_archive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch archiving of documents. @@ -928,7 +945,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_unarchive_success( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test successful batch unarchiving of documents. @@ -970,7 +987,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["add_task"].delay.assert_called_once_with(document.id) def test_batch_update_document_status_empty_list( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test handling of empty document list. @@ -996,7 +1013,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: mock_document_service_dependencies["remove_task"].delay.assert_not_called() def test_batch_update_document_status_document_indexing_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when document is being indexed. @@ -1073,7 +1090,7 @@ class TestDocumentServiceRenameDocument: "current_user": mock_current_user, } - def test_rename_document_success(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_success(self, db_session_with_containers: Session, mock_document_service_dependencies): """ Test successful document renaming. @@ -1111,7 +1128,9 @@ class TestDocumentServiceRenameDocument: assert result == document assert document.name == new_name - def test_rename_document_with_built_in_fields(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_built_in_fields( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with built-in fields enabled. @@ -1154,7 +1173,9 @@ class TestDocumentServiceRenameDocument: assert document.doc_metadata["document_name"] == new_name assert document.doc_metadata["existing_key"] == "existing_value" - def test_rename_document_with_upload_file(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_with_upload_file( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test document renaming with associated upload file. @@ -1202,7 +1223,7 @@ class TestDocumentServiceRenameDocument: assert upload_file.name == new_name def test_rename_document_dataset_not_found_error( - self, db_session_with_containers, mock_document_service_dependencies + self, db_session_with_containers: Session, mock_document_service_dependencies ): """ Test error when dataset is not found. @@ -1224,7 +1245,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Dataset not found"): DocumentService.rename_document(dataset_id, document_id, new_name) - def test_rename_document_not_found_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_not_found_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when document is not found. @@ -1251,7 +1274,9 @@ class TestDocumentServiceRenameDocument: with pytest.raises(ValueError, match="Document not found"): DocumentService.rename_document(dataset.id, document_id, new_name) - def test_rename_document_permission_error(self, db_session_with_containers, mock_document_service_dependencies): + def test_rename_document_permission_error( + self, db_session_with_containers: Session, mock_document_service_dependencies + ): """ Test error when user lacks permission. diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py index 4e8255d8ed..e73c2afe7f 100644 --- a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -11,6 +11,7 @@ from uuid import uuid4 import pytest from redis import RedisError +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from models.account import TenantAccountJoin @@ -122,7 +123,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_multiple_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -144,7 +145,7 @@ class TestSyncAccountDeletion: assert queued_workspace_ids == set(tenant_ids) def test_sync_account_deletion_no_workspaces( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: mock_config.ENTERPRISE_ENABLED = True @@ -155,7 +156,7 @@ class TestSyncAccountDeletion: mock_queue_task.assert_not_called() def test_sync_account_deletion_partial_failure( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_ids = [str(uuid4()) for _ in range(3)] @@ -180,7 +181,7 @@ class TestSyncAccountDeletion: assert mock_queue_task.call_count == 3 def test_sync_account_deletion_all_failures( - self, flask_app_with_containers, db_session_with_containers, mock_queue_task + self, flask_app_with_containers, db_session_with_containers: Session, mock_queue_task ): account_id = str(uuid4()) tenant_id = str(uuid4()) diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py index 2b842629a7..724dd19f92 100644 --- a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -3,6 +3,8 @@ from __future__ import annotations from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from models.model import App, RecommendedApp, Site from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval from services.recommend_app.recommend_app_type import RecommendAppType @@ -91,7 +93,7 @@ class TestDatabaseRecommendAppRetrieval: class TestFetchRecommendedAppsFromDb: - def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -111,7 +113,9 @@ class TestFetchRecommendedAppsFromDb: assert "assistant" in result["categories"] assert "writing" in result["categories"] - def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + def test_falls_back_to_default_language_when_empty( + self, flask_app_with_containers, db_session_with_containers: Session + ): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) @@ -124,7 +128,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id in app_ids - def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_site(db_session_with_containers, app_id=app1.id) @@ -137,7 +141,7 @@ class TestFetchRecommendedAppsFromDb: app_ids = {r["app_id"] for r in result["recommended_apps"]} assert app1.id not in app_ids - def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -151,12 +155,12 @@ class TestFetchRecommendedAppsFromDb: class TestFetchRecommendedAppDetailFromDb: - def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers: Session): result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) assert result is None - def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) _create_recommended_app(db_session_with_containers, app_id=app1.id) @@ -168,7 +172,7 @@ class TestFetchRecommendedAppDetailFromDb: assert result is None @patch("services.recommend_app.database.database_retrieval.AppDslService") - def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) _create_site(db_session_with_containers, app_id=app1.id) diff --git a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py index 3ec265d009..f78037e503 100644 --- a/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py +++ b/api/tests/test_containers_integration_tests/services/test_advanced_prompt_template_service.py @@ -2,6 +2,7 @@ import copy import pytest from faker import Faker +from sqlalchemy.orm import Session from core.prompt.prompt_templates.advanced_prompt_templates import ( BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, @@ -29,7 +30,9 @@ class TestAdvancedPromptTemplateService: # for consistency with other test files return {} - def test_get_prompt_baichuan_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_baichuan_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for Baichuan model. @@ -64,7 +67,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_prompt_common_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_common_model_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful prompt generation for common models. @@ -100,7 +105,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_prompt_case_insensitive_baichuan_detection( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan model detection is case insensitive. @@ -131,7 +136,7 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text def test_get_common_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for chat app with completion mode. @@ -161,7 +166,9 @@ class TestAdvancedPromptTemplateService: assert "{{#histories#}}" in prompt_text assert "{{#query#}}" in prompt_text - def test_get_common_prompt_chat_app_chat_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_chat_app_chat_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation for chat app with chat mode. @@ -189,7 +196,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with completion mode. @@ -217,7 +224,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_common_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation for completion app with chat mode. @@ -245,7 +252,9 @@ class TestAdvancedPromptTemplateService: assert CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_common_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_common_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test common prompt generation without context. @@ -273,7 +282,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_common_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported app mode. @@ -291,7 +300,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_common_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test common prompt generation with unsupported model mode. @@ -308,7 +317,9 @@ class TestAdvancedPromptTemplateService: # Assert: Verify empty dict is returned assert result == {} - def test_get_completion_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_completion_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test completion prompt generation with context. @@ -339,7 +350,7 @@ class TestAdvancedPromptTemplateService: assert result_text == CONTEXT + original_text def test_get_completion_prompt_without_context( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test completion prompt generation without context. @@ -368,7 +379,9 @@ class TestAdvancedPromptTemplateService: assert result_text == original_text assert CONTEXT not in result_text - def test_get_chat_prompt_with_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_with_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation with context. @@ -399,7 +412,9 @@ class TestAdvancedPromptTemplateService: assert original_text in result_text assert result_text == CONTEXT + original_text - def test_get_chat_prompt_without_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_chat_prompt_without_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test chat prompt generation without context. @@ -429,7 +444,7 @@ class TestAdvancedPromptTemplateService: assert CONTEXT not in result_text def test_get_baichuan_prompt_chat_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with completion mode. @@ -460,7 +475,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_chat_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for chat app with chat mode. @@ -489,7 +504,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_completion_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with completion mode. @@ -517,7 +532,7 @@ class TestAdvancedPromptTemplateService: assert "{{#pre_prompt#}}" in prompt_text def test_get_baichuan_prompt_completion_app_chat_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation for completion app with chat mode. @@ -545,7 +560,9 @@ class TestAdvancedPromptTemplateService: assert BAICHUAN_CONTEXT in prompt_text assert "{{#pre_prompt#}}" in prompt_text - def test_get_baichuan_prompt_no_context(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_baichuan_prompt_no_context( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test Baichuan prompt generation without context. @@ -573,7 +590,7 @@ class TestAdvancedPromptTemplateService: assert "{{#query#}}" in prompt_text def test_get_baichuan_prompt_unsupported_app_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported app mode. @@ -591,7 +608,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_baichuan_prompt_unsupported_model_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test Baichuan prompt generation with unsupported model mode. @@ -609,7 +626,7 @@ class TestAdvancedPromptTemplateService: assert result == {} def test_get_prompt_all_app_modes_common_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with common model. @@ -641,7 +658,7 @@ class TestAdvancedPromptTemplateService: assert result != {} def test_get_prompt_all_app_modes_baichuan_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prompt generation for all app modes with Baichuan model. @@ -672,7 +689,7 @@ class TestAdvancedPromptTemplateService: assert result is not None assert result != {} - def test_get_prompt_edge_cases(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_prompt_edge_cases(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test prompt generation with edge cases. @@ -704,7 +721,7 @@ class TestAdvancedPromptTemplateService: # Should either return a valid result or empty dict, but not crash assert result is not None - def test_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_template_immutability(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that original templates are not modified. @@ -738,7 +755,9 @@ class TestAdvancedPromptTemplateService: assert original_completion_completion == COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_completion_chat == COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_baichuan_template_immutability(self, db_session_with_containers, mock_external_service_dependencies): + def test_baichuan_template_immutability( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that original Baichuan templates are not modified. @@ -772,7 +791,9 @@ class TestAdvancedPromptTemplateService: assert original_baichuan_completion_completion == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG assert original_baichuan_completion_chat == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG - def test_context_integration_consistency(self, db_session_with_containers, mock_external_service_dependencies): + def test_context_integration_consistency( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test consistency of context integration across different scenarios. @@ -828,7 +849,7 @@ class TestAdvancedPromptTemplateService: assert prompt_text.startswith(CONTEXT) def test_baichuan_context_integration_consistency( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test consistency of Baichuan context integration across different scenarios. diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 1835650c42..6b844615b5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -10,6 +10,8 @@ from uuid import uuid4 import pytest import yaml from faker import Faker +from flask import Flask +from sqlalchemy.orm import Session from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, @@ -88,7 +90,7 @@ class TestAppDslService: """Integration tests for AppDslService using testcontainers.""" @pytest.fixture - def app(self, flask_app_with_containers): + def app(self, flask_app_with_containers: Flask): return flask_app_with_containers @pytest.fixture @@ -129,7 +131,7 @@ class TestAppDslService: "enterprise_service": mock_enterprise_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): fake = Faker() with patch("services.account_service.FeatureService") as mock_account_feature_service: mock_account_feature_service.get_system_features.return_value.is_allow_register = True @@ -206,7 +208,7 @@ class TestAppDslService: # ── Import: Validation ──────────────────────────────────────────── - def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers): + def test_import_app_invalid_import_mode_raises_value_error(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Invalid import_mode"): service.import_app( @@ -215,7 +217,7 @@ class TestAppDslService: yaml_content="version: '0.1.0'", ) - def test_import_app_missing_yaml_content(self, db_session_with_containers): + def test_import_app_missing_yaml_content(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -225,7 +227,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "yaml_content is required" in result.error - def test_import_app_missing_yaml_url(self, db_session_with_containers): + def test_import_app_missing_yaml_url(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -235,7 +237,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "yaml_url is required" in result.error - def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers): + def test_import_app_yaml_not_mapping_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -245,7 +247,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "content must be a mapping" in result.error - def test_import_app_version_not_str_returns_failed(self, db_session_with_containers): + def test_import_app_version_not_str_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) result = service.import_app( @@ -256,7 +258,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Invalid version type" in result.error - def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers): + def test_import_app_missing_app_data_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -266,7 +268,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Missing app data" in result.error - def test_import_app_yaml_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): def bad_safe_load(_content: str): raise yaml.YAMLError("bad") @@ -281,7 +283,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert result.error.startswith("Invalid YAML format:") - def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_unexpected_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): monkeypatch.setattr( AppDslService, "_create_or_update_app", @@ -299,7 +301,7 @@ class TestAppDslService: # ── Import: YAML URL ────────────────────────────────────────────── - def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_fetch_error_returns_failed(self, db_session_with_containers: Session, monkeypatch): monkeypatch.setattr( app_dsl_service.ssrf_proxy, "get", @@ -315,7 +317,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Error fetching YAML from URL: boom" in result.error - def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_empty_content_returns_failed(self, db_session_with_containers: Session, monkeypatch): response = MagicMock() response.content = b"" response.raise_for_status.return_value = None @@ -330,7 +332,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "Empty content" in result.error - def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_file_too_large_returns_failed(self, db_session_with_containers: Session, monkeypatch): response = MagicMock() response.content = b"x" * (DSL_MAX_SIZE + 1) response.raise_for_status.return_value = None @@ -345,7 +347,9 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "File size exceeds" in result.error - def test_import_app_yaml_url_user_attachments_keeps_original_url(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_user_attachments_keeps_original_url( + self, db_session_with_containers: Session, monkeypatch + ): yaml_url = "https://github.com/user-attachments/files/24290802/loop-test.yml" yaml_bytes = _pending_yaml_content() @@ -371,7 +375,7 @@ class TestAppDslService: assert result.imported_dsl_version == "99.0.0" assert requested_urls == [yaml_url] - def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers, monkeypatch): + def test_import_app_yaml_url_github_blob_rewrites_to_raw(self, db_session_with_containers: Session, monkeypatch): yaml_url = "https://github.com/acme/repo/blob/main/app.yml" raw_url = "https://raw.githubusercontent.com/acme/repo/main/app.yml" yaml_bytes = _pending_yaml_content() @@ -400,7 +404,7 @@ class TestAppDslService: # ── Import: App ID checks ──────────────────────────────────────── - def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers): + def test_import_app_app_id_not_found_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -412,7 +416,7 @@ class TestAppDslService: assert result.error == "App not found" def test_import_app_overwrite_only_allows_workflow_and_advanced_chat( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) assert app.mode == "chat" @@ -429,7 +433,7 @@ class TestAppDslService: # ── Import: Flow ────────────────────────────────────────────────── - def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers): + def test_import_app_pending_stores_import_info_in_redis(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.import_app( account=_account_mock(), @@ -449,7 +453,7 @@ class TestAppDslService: assert stored is not None def test_import_app_completed_uses_declared_dependencies( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) @@ -483,7 +487,7 @@ class TestAppDslService: @pytest.mark.parametrize("has_workflow", [True, False]) def test_import_app_legacy_versions_extract_dependencies( - self, db_session_with_containers, monkeypatch, has_workflow: bool + self, db_session_with_containers: Session, monkeypatch, has_workflow: bool ): monkeypatch.setattr( AppDslService, @@ -540,13 +544,13 @@ class TestAppDslService: # ── Confirm Import ──────────────────────────────────────────────── - def test_confirm_import_expired_returns_failed(self, db_session_with_containers): + def test_confirm_import_expired_returns_failed(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) result = service.confirm_import(import_id=str(uuid4()), account=_account_mock()) assert result.status == ImportStatus.FAILED assert "expired" in result.error - def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers, monkeypatch): + def test_confirm_import_success_deletes_redis_key(self, db_session_with_containers: Session, monkeypatch): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" @@ -579,7 +583,7 @@ class TestAppDslService: assert result.app_id == created_app.id assert redis_client.get(redis_key) is None - def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers): + def test_confirm_import_invalid_pending_data_type_returns_failed(self, db_session_with_containers: Session): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "123") @@ -589,7 +593,7 @@ class TestAppDslService: assert result.status == ImportStatus.FAILED assert "validation error" in result.error - def test_confirm_import_exception_returns_failed(self, db_session_with_containers): + def test_confirm_import_exception_returns_failed(self, db_session_with_containers: Session): import_id = str(uuid4()) redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" redis_client.setex(redis_key, IMPORT_INFO_REDIS_EXPIRY, "not-valid-json") @@ -600,13 +604,13 @@ class TestAppDslService: # ── Check Dependencies ──────────────────────────────────────────── - def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers): + def test_check_dependencies_returns_empty_when_no_redis_data(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) app_model = _app_stub() result = service.check_dependencies(app_model=app_model) assert result.leaked_dependencies == [] - def test_check_dependencies_calls_analysis_service(self, db_session_with_containers, monkeypatch): + def test_check_dependencies_calls_analysis_service(self, db_session_with_containers: Session, monkeypatch): app_id = str(uuid4()) pending = CheckDependenciesPendingData(dependencies=[], app_id=app_id) redis_client.setex( @@ -634,7 +638,9 @@ class TestAppDslService: result = service.check_dependencies(app_model=_app_stub(id=app_id)) assert len(result.leaked_dependencies) == 1 - def test_check_dependencies_with_real_app(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_dependencies_with_real_app( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) mock_dependencies_json = '{"app_id": "' + app.id + '", "dependencies": []}' @@ -650,12 +656,12 @@ class TestAppDslService: # ── Create/Update App ───────────────────────────────────────────── - def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers): + def test_create_or_update_app_missing_mode_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="loss app mode"): service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) - def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers, monkeypatch): + def test_create_or_update_app_existing_app_updates_fields(self, db_session_with_containers: Session, monkeypatch): fixed_now = object() monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) @@ -707,7 +713,7 @@ class TestAppDslService: assert app.icon_background == "#222222" assert app.updated_at is fixed_now - def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers): + def test_create_or_update_app_new_app_requires_tenant(self, db_session_with_containers: Session): account = _account_mock() account.current_tenant_id = None service = AppDslService(db_session_with_containers) @@ -719,7 +725,7 @@ class TestAppDslService: ) def test_create_or_update_app_creates_workflow_app_and_saves_dependencies( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): _, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) @@ -755,7 +761,7 @@ class TestAppDslService: stored = redis_client.get(f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{app.id}") assert stored is not None - def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers): + def test_create_or_update_app_workflow_missing_workflow_data_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing workflow data"): service._create_or_update_app( @@ -764,7 +770,7 @@ class TestAppDslService: account=_account_mock(), ) - def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers): + def test_create_or_update_app_chat_requires_model_config(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing model_config"): service._create_or_update_app( @@ -774,7 +780,7 @@ class TestAppDslService: ) def test_create_or_update_app_chat_creates_model_config_and_sends_event( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.app_model_config_id = None @@ -793,7 +799,7 @@ class TestAppDslService: db_session_with_containers.expire_all() assert app.app_model_config_id is not None - def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers): + def test_create_or_update_app_invalid_mode_raises(self, db_session_with_containers: Session): service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Invalid app mode"): service._create_or_update_app( @@ -870,7 +876,7 @@ class TestAppDslService: assert data["app"]["icon_type"] == "image" assert data["app"]["icon_background"] == "#FFEAD5" - def test_export_dsl_chat_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_chat_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) model_config = AppModelConfig( @@ -908,7 +914,9 @@ class TestAppDslService: assert "model_config" in exported_data assert "dependencies" in exported_data - def test_export_dsl_workflow_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_workflow_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" db_session_with_containers.commit() @@ -941,7 +949,9 @@ class TestAppDslService: assert "workflow" in exported_data assert "dependencies" in exported_data - def test_export_dsl_with_workflow_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_export_dsl_with_workflow_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" db_session_with_containers.commit() @@ -981,7 +991,7 @@ class TestAppDslService: assert "workflow" in exported_data def test_export_dsl_with_invalid_workflow_id_raises_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) app.mode = "workflow" diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py index 768a8baee2..d0c07f0de8 100644 --- a/api/tests/test_containers_integration_tests/services/test_attachment_service.py +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -7,7 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound import services.attachment_service as attachment_service_module @@ -19,7 +19,7 @@ from services.attachment_service import AttachmentService class TestAttachmentService: - def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + def _create_upload_file(self, db_session_with_containers: Session, *, tenant_id: str | None = None) -> UploadFile: upload_file = UploadFile( tenant_id=tenant_id or str(uuid4()), storage_type=StorageType.OPENDAL, @@ -60,7 +60,7 @@ class TestAttachmentService: with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): AttachmentService(session_factory=invalid_session_factory) - def test_should_return_base64_when_file_exists(self, db_session_with_containers): + def test_should_return_base64_when_file_exists(self, db_session_with_containers: Session): upload_file = self._create_upload_file(db_session_with_containers) service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) @@ -70,7 +70,7 @@ class TestAttachmentService: assert result == base64.b64encode(b"binary-content").decode() mock_load.assert_called_once_with(upload_file.key) - def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers: Session): service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) with patch.object(attachment_service_module.storage, "load_once") as mock_load: diff --git a/api/tests/test_containers_integration_tests/services/test_billing_service.py b/api/tests/test_containers_integration_tests/services/test_billing_service.py index 8092c7ad75..4893126d7f 100644 --- a/api/tests/test_containers_integration_tests/services/test_billing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_billing_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from extensions.ext_redis import redis_client @@ -24,7 +25,7 @@ class TestBillingServiceGetPlanBulkWithCache: """ @pytest.fixture(autouse=True) - def setup_redis_cleanup(self, flask_app_with_containers): + def setup_redis_cleanup(self, flask_app_with_containers: Flask): """Clean up Redis cache before and after each test.""" with flask_app_with_containers.app_context(): # Clean up before test @@ -56,7 +57,7 @@ class TestBillingServiceGetPlanBulkWithCache: return value return None - def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -87,7 +88,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was not called mock_get_plan_bulk.assert_not_called() - def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_all_cache_miss(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when all tenants are not in cache.""" with flask_app_with_containers.app_context(): # Arrange @@ -127,7 +128,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert ttl_1 > 0 assert ttl_1 <= 600 # Should be <= 600 seconds - def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_partial_cache_hit(self, flask_app_with_containers: Flask): """Test bulk plan retrieval when some tenants are in cache, some are not.""" with flask_app_with_containers.app_context(): # Arrange @@ -158,7 +159,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == missing_plan["tenant-3"] - def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_mget_failure(self, flask_app_with_containers: Flask): """Test fallback to API when Redis mget fails.""" with flask_app_with_containers.app_context(): # Arrange @@ -189,7 +190,7 @@ class TestBillingServiceGetPlanBulkWithCache: assert cached_1 is not None assert cached_2 is not None - def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_json_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache contains invalid JSON.""" with flask_app_with_containers.app_context(): # Arrange @@ -241,7 +242,7 @@ class TestBillingServiceGetPlanBulkWithCache: cached_data_3 = json.loads(cached_3) assert cached_data_3 == expected_plans["tenant-3"] - def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_invalid_plan_data_in_cache(self, flask_app_with_containers: Flask): """Test fallback to API when cache data doesn't match SubscriptionPlan schema.""" with flask_app_with_containers.app_context(): # Arrange @@ -274,7 +275,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify API was called for tenant-2 and tenant-3 mock_get_plan_bulk.assert_called_once_with(["tenant-2", "tenant-3"]) - def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_redis_pipeline_failure(self, flask_app_with_containers: Flask): """Test that pipeline failure doesn't affect return value.""" with flask_app_with_containers.app_context(): # Arrange @@ -303,7 +304,7 @@ class TestBillingServiceGetPlanBulkWithCache: # Verify pipeline was attempted mock_pipeline.assert_called_once() - def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_empty_tenant_ids(self, flask_app_with_containers: Flask): """Test with empty tenant_ids list.""" with flask_app_with_containers.app_context(): # Act @@ -321,7 +322,7 @@ class TestBillingServiceGetPlanBulkWithCache: # But we should check that mget was not called at all # Since we can't easily verify this without more mocking, we just verify the result - def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers): + def test_get_plan_bulk_with_cache_ttl_expired(self, flask_app_with_containers: Flask): """Test that expired cache keys are treated as cache misses.""" with flask_app_with_containers.app_context(): # Arrange diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 98c38f2b5f..8aa10129c1 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -7,6 +7,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -170,7 +171,7 @@ class ConversationServiceIntegrationTestDataFactory: class TestConversationServicePagination: """Test conversation pagination operations.""" - def test_pagination_with_non_empty_include_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_include_ids(self, db_session_with_containers: Session): """ Test that non-empty include_ids filters properly. @@ -204,7 +205,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[0].id, conversations[1].id} - def test_pagination_with_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that empty exclude_ids doesn't filter. @@ -237,7 +238,7 @@ class TestConversationServicePagination: # Assert assert len(result.data) == len(conversations) - def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers): + def test_pagination_with_non_empty_exclude_ids(self, db_session_with_containers: Session): """ Test that non-empty exclude_ids filters properly. @@ -271,7 +272,7 @@ class TestConversationServicePagination: returned_ids = {conversation.id for conversation in result.data} assert returned_ids == {conversations[2].id} - def test_pagination_with_sorting_descending(self, db_session_with_containers): + def test_pagination_with_sorting_descending(self, db_session_with_containers: Session): """ Test pagination with descending sort order. @@ -316,7 +317,7 @@ class TestConversationServiceMessageCreation: within conversations. """ - def test_pagination_by_first_id_without_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_without_first_id(self, db_session_with_containers: Session): """ Test message pagination without specifying first_id. @@ -354,7 +355,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 3 # All 3 messages returned assert result.has_more is False # No more messages available (3 < limit of 10) - def test_pagination_by_first_id_with_first_id(self, db_session_with_containers): + def test_pagination_by_first_id_with_first_id(self, db_session_with_containers: Session): """ Test message pagination with first_id specified. @@ -399,7 +400,9 @@ class TestConversationServiceMessageCreation: assert len(result.data) == 2 # Only 2 messages returned after first_id assert result.has_more is False # No more messages available (2 < limit of 10) - def test_pagination_by_first_id_raises_error_when_first_message_not_found(self, db_session_with_containers): + def test_pagination_by_first_id_raises_error_when_first_message_not_found( + self, db_session_with_containers: Session + ): """ Test that FirstMessageNotExistsError is raised when first_id doesn't exist. @@ -424,7 +427,7 @@ class TestConversationServiceMessageCreation: limit=10, ) - def test_pagination_with_has_more_flag(self, db_session_with_containers): + def test_pagination_with_has_more_flag(self, db_session_with_containers: Session): """ Test that has_more flag is correctly set when there are more messages. @@ -463,7 +466,7 @@ class TestConversationServiceMessageCreation: assert len(result.data) == limit # Extra message should be removed assert result.has_more is True # Flag should be set - def test_pagination_with_ascending_order(self, db_session_with_containers): + def test_pagination_with_ascending_order(self, db_session_with_containers: Session): """ Test message pagination with ascending order. @@ -512,7 +515,7 @@ class TestConversationServiceSummarization: """ @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_success(self, mock_llm_generator, db_session_with_containers: Session): """ Test successful auto-generation of conversation name. @@ -552,7 +555,7 @@ class TestConversationServiceSummarization: app_model.tenant_id, first_message.query, conversation.id, app_model.id ) - def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers): + def test_auto_generate_name_raises_error_when_no_message(self, db_session_with_containers: Session): """ Test that MessageNotExistsError is raised when conversation has no messages. @@ -571,7 +574,9 @@ class TestConversationServiceSummarization: ConversationService.auto_generate_name(app_model, conversation) @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_auto_generate_name_handles_llm_failure_gracefully(self, mock_llm_generator, db_session_with_containers): + def test_auto_generate_name_handles_llm_failure_gracefully( + self, mock_llm_generator, db_session_with_containers: Session + ): """ Test that LLM generation failures are suppressed and don't crash. @@ -604,7 +609,7 @@ class TestConversationServiceSummarization: assert conversation.name == original_name # Name remains unchanged @patch("services.conversation_service.naive_utc_now") - def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers): + def test_rename_with_manual_name(self, mock_naive_utc_now, db_session_with_containers: Session): """ Test renaming conversation with manual name. @@ -638,7 +643,7 @@ class TestConversationServiceSummarization: assert conversation.updated_at == mock_time @patch("services.conversation_service.LLMGenerator.generate_conversation_name") - def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers): + def test_rename_with_auto_generate(self, mock_llm_generator, db_session_with_containers: Session): """ Test rename delegates to auto_generate_name when auto_generate is True. @@ -682,7 +687,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_from_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_from_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating annotation from existing message. @@ -721,7 +728,9 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_create_annotation_without_message(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_create_annotation_without_message( + self, mock_current_account, mock_add_task, db_session_with_containers: Session + ): """ Test creating standalone annotation without message. @@ -753,7 +762,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_update_existing_annotation(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test updating an existing annotation. @@ -800,7 +809,7 @@ class TestConversationServiceMessageAnnotation: mock_add_task.delay.assert_not_called() @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving paginated annotation list. @@ -836,7 +845,7 @@ class TestConversationServiceMessageAnnotation: assert result_total == 5 @patch("services.annotation_service.current_account_with_tenant") - def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers): + def test_get_annotation_list_with_keyword_search(self, mock_current_account, db_session_with_containers: Session): """ Test retrieving annotations with keyword filtering. @@ -885,7 +894,7 @@ class TestConversationServiceMessageAnnotation: @patch("services.annotation_service.add_annotation_to_index_task") @patch("services.annotation_service.current_account_with_tenant") - def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers): + def test_insert_annotation_directly(self, mock_current_account, mock_add_task, db_session_with_containers: Session): """ Test direct annotation insertion without message reference. @@ -919,7 +928,7 @@ class TestConversationServiceExport: Tests retrieving conversation data for export purposes. """ - def test_get_conversation_success(self, db_session_with_containers): + def test_get_conversation_success(self, db_session_with_containers: Session): """Test successful retrieval of conversation.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -937,7 +946,7 @@ class TestConversationServiceExport: # Assert assert result == conversation - def test_get_conversation_not_found(self, db_session_with_containers): + def test_get_conversation_not_found(self, db_session_with_containers: Session): """Test ConversationNotExistsError when conversation doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -949,7 +958,7 @@ class TestConversationServiceExport: ConversationService.get_conversation(app_model=app_model, conversation_id=str(uuid4()), user=user) @patch("services.annotation_service.current_account_with_tenant") - def test_export_annotation_list(self, mock_current_account, db_session_with_containers): + def test_export_annotation_list(self, mock_current_account, db_session_with_containers: Session): """Test exporting all annotations for an app.""" # Arrange app_model, account = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -977,7 +986,7 @@ class TestConversationServiceExport: # Assert assert len(result) == 10 - def test_get_message_success(self, db_session_with_containers): + def test_get_message_success(self, db_session_with_containers: Session): """Test successful retrieval of a message.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -1001,7 +1010,7 @@ class TestConversationServiceExport: # Assert assert result == message - def test_get_message_not_found(self, db_session_with_containers): + def test_get_message_not_found(self, db_session_with_containers: Session): """Test MessageNotExistsError when message doesn't exist.""" # Arrange app_model, user = ConversationServiceIntegrationTestDataFactory.create_app_and_account( @@ -1012,7 +1021,7 @@ class TestConversationServiceExport: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app_model, user=user, message_id=str(uuid4())) - def test_get_conversation_for_end_user(self, db_session_with_containers): + def test_get_conversation_for_end_user(self, db_session_with_containers: Session): """ Test retrieving conversation created by end user via API. @@ -1038,7 +1047,7 @@ class TestConversationServiceExport: assert result == conversation @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation(self, mock_delete_task, db_session_with_containers: Session): """ Test conversation deletion with async cleanup. @@ -1071,7 +1080,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_called_once_with(conversation_id) @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers): + def test_delete_conversation_not_owned_by_account(self, mock_delete_task, db_session_with_containers: Session): """ Test deletion is denied when conversation belongs to a different account. """ @@ -1102,7 +1111,7 @@ class TestConversationServiceExport: mock_delete_task.delay.assert_not_called() @patch("services.conversation_service.delete_conversation_related_data") - def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers): + def test_delete_handles_exception_and_rollback(self, mock_delete_task, db_session_with_containers: Session): """ Test that delete propagates exceptions and does not trigger the cleanup task. diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py index 0b7bd9ca64..6c292dbc4b 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service_variables.py @@ -5,7 +5,8 @@ from unittest.mock import patch from uuid import uuid4 import pytest -from sqlalchemy.orm import sessionmaker +from flask import Flask +from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db @@ -149,7 +150,7 @@ class ConversationServiceVariableIntegrationFactory: @pytest.fixture -def real_conversation_service_session_factory(flask_app_with_containers): +def real_conversation_service_session_factory(flask_app_with_containers: Flask): del flask_app_with_containers real_session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) @@ -162,7 +163,7 @@ def real_conversation_service_session_factory(flask_app_with_containers): class TestConversationServiceVariables: def test_get_conversational_variable_success( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -200,7 +201,7 @@ class TestConversationServiceVariables: assert result.has_more is False def test_get_conversational_variable_with_last_id( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -242,7 +243,7 @@ class TestConversationServiceVariables: assert result.has_more is False def test_get_conversational_variable_last_id_not_found_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -259,7 +260,7 @@ class TestConversationServiceVariables: ) def test_get_conversational_variable_sets_has_more( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -287,7 +288,7 @@ class TestConversationServiceVariables: assert result.has_more is True def test_update_conversation_variable_success( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -320,7 +321,7 @@ class TestConversationServiceVariables: assert result["updated_at"] == updated_at def test_update_conversation_variable_not_found_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -337,7 +338,7 @@ class TestConversationServiceVariables: ) def test_update_conversation_variable_type_mismatch_raises_error( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -360,7 +361,7 @@ class TestConversationServiceVariables: ) def test_update_conversation_variable_integer_number_compatibility( - self, db_session_with_containers, real_conversation_service_session_factory + self, db_session_with_containers: Session, real_conversation_service_session_factory ): del real_conversation_service_session_factory factory = ConversationServiceVariableIntegrationFactory @@ -390,7 +391,7 @@ class TestConversationServiceVariables: class TestConversationServicePaginationWithContainers: - def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers): + def test_pagination_by_last_id_raises_error_when_last_id_missing(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) @@ -404,7 +405,7 @@ class TestConversationServicePaginationWithContainers: invoke_from=InvokeFrom.WEB_APP, ) - def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers): + def test_pagination_by_last_id_with_default_desc_updated_at(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) base_time = datetime(2024, 1, 1, 8, 0, 0) @@ -442,7 +443,7 @@ class TestConversationServicePaginationWithContainers: assert newest.id != middle.id assert [conversation.id for conversation in result.data] == [oldest.id] - def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers): + def test_pagination_by_last_id_with_name_sort(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) alpha = factory.create_conversation(db_session_with_containers, app, account, name="Alpha") @@ -462,7 +463,7 @@ class TestConversationServicePaginationWithContainers: assert alpha.id != beta.id assert [conversation.id for conversation in result.data] == [gamma.id] - def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers): + def test_pagination_filters_to_end_user_api_source(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) end_user = factory.create_end_user(db_session_with_containers, app) @@ -493,7 +494,7 @@ class TestConversationServicePaginationWithContainers: assert account_conversation.id != end_user_conversation.id assert [conversation.id for conversation in result.data] == [end_user_conversation.id] - def test_pagination_filters_to_account_console_source(self, db_session_with_containers): + def test_pagination_filters_to_account_console_source(self, db_session_with_containers: Session): factory = ConversationServiceVariableIntegrationFactory app, account = factory.create_app_and_account(db_session_with_containers) end_user = factory.create_end_user(db_session_with_containers, app) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 02ab3f8314..638a962f18 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,7 +3,7 @@ from uuid import uuid4 import pytest -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from extensions.ext_database import db from graphon.variables import StringVariable @@ -13,7 +13,12 @@ from services.conversation_variable_updater import ConversationVariableNotFoundE class TestConversationVariableUpdater: def _create_conversation_variable( - self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + self, + db_session_with_containers: Session, + *, + conversation_id: str, + variable: StringVariable, + app_id: str | None = None, ) -> ConversationVariable: row = ConversationVariable( id=variable.id, @@ -25,7 +30,7 @@ class TestConversationVariableUpdater: db_session_with_containers.commit() return row - def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="old value") self._create_conversation_variable( @@ -42,7 +47,7 @@ class TestConversationVariableUpdater: assert row is not None assert row.data == updated_variable.model_dump_json() - def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers: Session): conversation_id = str(uuid4()) variable = StringVariable(id=str(uuid4()), name="topic", value="value") updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) @@ -50,7 +55,7 @@ class TestConversationVariableUpdater: with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): updater.update(conversation_id=conversation_id, variable=variable) - def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers: Session): updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) result = updater.flush() diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py index 0f63d98642..09ba041244 100644 --- a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -3,6 +3,7 @@ from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.errors.error import QuotaExceededError from models import TenantCreditPool @@ -14,7 +15,7 @@ class TestCreditPoolService: def _create_tenant_id(self) -> str: return str(uuid4()) - def test_create_default_pool(self, db_session_with_containers): + def test_create_default_pool(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) @@ -25,7 +26,7 @@ class TestCreditPoolService: assert pool.quota_used == 0 assert pool.quota_limit > 0 - def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -35,17 +36,17 @@ class TestCreditPoolService: assert result.tenant_id == tenant_id assert result.pool_type == ProviderQuotaType.TRIAL - def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers: Session): result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type=ProviderQuotaType.TRIAL) assert result is None - def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers: Session): result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) assert result is False - def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) @@ -53,7 +54,7 @@ class TestCreditPoolService: assert result is True - def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) # Exhaust credits @@ -64,11 +65,11 @@ class TestCreditPoolService: assert result is False - def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers: Session): with pytest.raises(QuotaExceededError, match="Credit pool not found"): CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) - def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) pool.quota_used = pool.quota_limit @@ -77,7 +78,7 @@ class TestCreditPoolService: with pytest.raises(QuotaExceededError, match="No credits remaining"): CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) - def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() CreditPoolService.create_default_pool(tenant_id) credits_required = 10 @@ -89,7 +90,7 @@ class TestCreditPoolService: pool = CreditPoolService.get_pool(tenant_id=tenant_id) assert pool.quota_used == credits_required - def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers: Session): tenant_id = self._create_tenant_id() pool = CreditPoolService.create_default_pool(tenant_id) remaining = 5 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index 71c8874f79..f9898e2cfa 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -8,6 +8,7 @@ checks with testcontainers-backed infrastructure instead of database-chain mocks from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType from extensions.ext_database import db @@ -107,7 +108,7 @@ class DatasetPermissionTestDataFactory: class TestDatasetPermissionServiceGetPartialMemberList: """Verify partial-member list reads against persisted DatasetPermission rows.""" - def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers: Session): """ Test retrieving partial member list with multiple members. """ @@ -138,7 +139,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 3 - def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers: Session): """ Test retrieving partial member list with single member. """ @@ -160,7 +161,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: assert set(result) == set(expected_account_ids) assert len(result) == 1 - def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: Session): """ Test retrieving partial member list when no members exist. """ @@ -179,7 +180,7 @@ class TestDatasetPermissionServiceGetPartialMemberList: class TestDatasetPermissionServiceUpdatePartialMemberList: """Verify partial-member list updates against persisted DatasetPermission rows.""" - def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + def test_update_partial_member_list_add_new_members(self, db_session_with_containers: Session): """ Test adding new partial members to a dataset. """ @@ -203,7 +204,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {member_1.id, member_2.id} - def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session): """ Test replacing existing partial members with new ones. """ @@ -239,7 +240,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert set(result) == {new_member_1.id, new_member_2.id} - def test_update_partial_member_list_empty_list(self, db_session_with_containers): + def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test updating with empty member list (clearing all members). """ @@ -264,7 +265,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -313,7 +314,7 @@ class TestDatasetPermissionServiceUpdatePartialMemberList: class TestDatasetPermissionServiceClearPartialMemberList: """Verify partial-member clearing against persisted DatasetPermission rows.""" - def test_clear_partial_member_list_success(self, db_session_with_containers): + def test_clear_partial_member_list_success(self, db_session_with_containers: Session): """ Test successful clearing of partial member list. """ @@ -338,7 +339,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session): """ Test clearing partial member list when no members exist. """ @@ -353,7 +354,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert result == [] - def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): """ Test error handling and rollback on database error. """ @@ -398,7 +399,7 @@ class TestDatasetPermissionServiceClearPartialMemberList: class TestDatasetServiceCheckDatasetPermission: """Verify dataset access checks against persisted partial-member permissions.""" - def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers: Session): """Test that users from different tenants cannot access dataset.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -410,7 +411,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other_user) - def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers: Session): """Test that tenant owners can access any dataset regardless of permission level.""" owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -423,7 +424,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, owner) - def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers: Session): """Test ONLY_ME permission allows only the dataset creator to access.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) @@ -433,7 +434,7 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, creator) - def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers: Session): """Test ONLY_ME permission denies access to non-creators.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -447,7 +448,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError): DatasetService.check_dataset_permission(dataset, other) - def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers: Session): """Test ALL_TEAM permission allows any team member to access the dataset.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( @@ -460,7 +461,9 @@ class TestDatasetServiceCheckDatasetPermission: DatasetService.check_dataset_permission(dataset, member) - def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_with_permission_success( + self, db_session_with_containers: Session + ): """ Test that user with explicit permission can access partial_members dataset. """ @@ -485,7 +488,9 @@ class TestDatasetServiceCheckDatasetPermission: permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) assert user.id in permissions - def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + def test_check_dataset_permission_partial_members_without_permission_error( + self, db_session_with_containers: Session + ): """ Test error when user without permission tries to access partial_members dataset. """ @@ -506,7 +511,7 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers: Session): """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0de3c64c4f..e6ee896a52 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -712,7 +712,7 @@ class TestDatasetServiceRetrievalConfiguration: class TestDocumentServicePauseRecoverRetry: """Tests for pause/recover/retry orchestration using real DB and Redis.""" - def _create_indexing_document(self, db_session_with_containers, indexing_status="indexing"): + def _create_indexing_document(self, db_session_with_containers: Session, indexing_status="indexing"): factory = DatasetServiceIntegrationDataFactory account, tenant = factory.create_account_with_tenant(db_session_with_containers) dataset = factory.create_dataset(db_session_with_containers, tenant.id, account.id) @@ -721,7 +721,7 @@ class TestDocumentServicePauseRecoverRetry: db_session_with_containers.commit() return doc, account - def test_pause_document_success(self, db_session_with_containers): + def test_pause_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -740,7 +740,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is not None redis_client.delete(cache_key) - def test_pause_document_invalid_status_error(self, db_session_with_containers): + def test_pause_document_invalid_status_error(self, db_session_with_containers: Session): from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -751,7 +751,7 @@ class TestDocumentServicePauseRecoverRetry: with pytest.raises(DocumentIndexingError): DocumentService.pause_document(doc) - def test_recover_document_success(self, db_session_with_containers): + def test_recover_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService @@ -775,7 +775,7 @@ class TestDocumentServicePauseRecoverRetry: assert redis_client.get(cache_key) is None recover_task.delay.assert_called_once_with(doc.dataset_id, doc.id) - def test_retry_document_indexing_success(self, db_session_with_containers): + def test_retry_document_indexing_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client from services.dataset_service import DocumentService diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py index c486ff5613..08de79f4b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -6,6 +6,7 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from models.account import Account, Tenant, TenantAccountJoin from services.dataset_service import DatasetService @@ -48,7 +49,7 @@ class TestDatasetServiceCreateRagPipelineDataset: permission="only_me", ) - def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers): + def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, db_session_with_containers: Session): tenant, _ = self._create_tenant_and_account(db_session_with_containers) mock_user = Mock(id=None) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index 3cac964d89..c43a5d5978 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -3,6 +3,8 @@ from unittest.mock import patch from uuid import uuid4 +from sqlalchemy.orm import Session + from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document @@ -101,7 +103,7 @@ class DatasetDeleteIntegrationDataFactory: class TestDatasetServiceDeleteDataset: """Integration coverage for DatasetService.delete_dataset using testcontainers.""" - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -144,7 +146,7 @@ class TestDatasetServiceDeleteDataset: dataset.pipeline_id, ) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -172,7 +174,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -200,7 +202,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers: Session): """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" # Arrange owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) @@ -228,7 +230,7 @@ class TestDatasetServiceDeleteDataset: assert db_session_with_containers.get(Dataset, dataset.id) is None clean_dataset_delay.assert_not_called() - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """Return False without scheduling cleanup when the target dataset does not exist.""" # Arrange owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py index 1b4179c9c7..0603a1e27f 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -6,6 +6,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -363,7 +364,7 @@ class TestDatasetServicePermissionsAndLifecycle: DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) - def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers): + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask): with flask_app_with_containers.app_context(): with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(str(uuid4()), True) @@ -473,7 +474,7 @@ class TestDatasetCollectionBindingServiceIntegration: assert persisted.type == "dataset" assert persisted.collection_name - def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask): with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index fe426ae516..69c39b8bfb 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -6,6 +6,7 @@ from datetime import UTC, datetime, timedelta from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from graphon.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -46,7 +47,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.commit() return run - def _create_archive_log(self, db_session_with_containers, *, run: WorkflowRun) -> None: + def _create_archive_log(self, db_session_with_containers: Session, *, run: WorkflowRun) -> None: archive_log = WorkflowArchiveLog( tenant_id=run.tenant_id, app_id=run.app_id, @@ -72,7 +73,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.add(archive_log) db_session_with_containers.commit() - def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers): + def test_delete_by_run_id_returns_error_when_run_missing(self, db_session_with_containers: Session): deleter = ArchivedWorkflowRunDeletion() missing_run_id = str(uuid4()) @@ -81,7 +82,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == f"Workflow run {missing_run_id} not found" - def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers): + def test_delete_by_run_id_returns_error_when_not_archived(self, db_session_with_containers: Session): tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, @@ -95,7 +96,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == f"Workflow run {run.id} is not archived" - def test_delete_batch_uses_repo(self, db_session_with_containers): + def test_delete_batch_uses_repo(self, db_session_with_containers: Session): tenant_id = str(uuid4()) base_time = datetime.now(UTC) run1 = self._create_workflow_run(db_session_with_containers, tenant_id=tenant_id, created_at=base_time) @@ -124,7 +125,7 @@ class TestArchivedWorkflowRunDeletion: ).all() assert remaining_runs == [] - def test_delete_run_calls_repo(self, db_session_with_containers): + def test_delete_run_calls_repo(self, db_session_with_containers: Session): tenant_id = str(uuid4()) run = self._create_workflow_run( db_session_with_containers, @@ -142,7 +143,7 @@ class TestArchivedWorkflowRunDeletion: deleted_run = db_session_with_containers.get(WorkflowRun, run_id) assert deleted_run is None - def test_delete_run_dry_run(self, db_session_with_containers): + def test_delete_run_dry_run(self, db_session_with_containers: Session): """Dry run should return success without actually deleting.""" tenant_id = str(uuid4()) run = self._create_workflow_run( @@ -161,7 +162,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(WorkflowRun, run_id) is not None - def test_delete_run_exception_returns_error(self, db_session_with_containers): + def test_delete_run_exception_returns_error(self, db_session_with_containers: Session): """Exception during deletion should return failure result.""" from unittest.mock import MagicMock, patch @@ -183,7 +184,7 @@ class TestArchivedWorkflowRunDeletion: assert result.success is False assert result.error == "Database error" - def test_delete_by_run_id_success(self, db_session_with_containers): + def test_delete_by_run_id_success(self, db_session_with_containers: Session): """Successfully delete an archived workflow run by ID.""" tenant_id = str(uuid4()) base_time = datetime.now(UTC) @@ -202,7 +203,7 @@ class TestArchivedWorkflowRunDeletion: db_session_with_containers.expunge_all() assert db_session_with_containers.get(WorkflowRun, run_id) is None - def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers): + def test_get_workflow_run_repo_caches_instance(self, db_session_with_containers: Session): """_get_workflow_run_repo should return a cached repo on subsequent calls.""" deleter = ArchivedWorkflowRunDeletion() diff --git a/api/tests/test_containers_integration_tests/services/test_end_user_service.py b/api/tests/test_containers_integration_tests/services/test_end_user_service.py index cafabc939b..074d448aab 100644 --- a/api/tests/test_containers_integration_tests/services/test_end_user_service.py +++ b/api/tests/test_containers_integration_tests/services/test_end_user_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin @@ -102,7 +103,7 @@ class TestEndUserServiceGetOrCreateEndUser: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers, factory): + def test_get_or_create_end_user_with_custom_user_id(self, db_session_with_containers: Session, factory): """Test getting or creating end user with custom user_id.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -118,7 +119,7 @@ class TestEndUserServiceGetOrCreateEndUser: assert result.type == InvokeFrom.SERVICE_API assert result.is_anonymous is False - def test_get_or_create_end_user_without_user_id(self, db_session_with_containers, factory): + def test_get_or_create_end_user_without_user_id(self, db_session_with_containers: Session, factory): """Test getting or creating end user without user_id uses default session.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -131,7 +132,7 @@ class TestEndUserServiceGetOrCreateEndUser: # Verify _is_anonymous is set correctly (property always returns False) assert result._is_anonymous is True - def test_get_existing_end_user(self, db_session_with_containers, factory): + def test_get_existing_end_user(self, db_session_with_containers: Session, factory): """Test retrieving an existing end user.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -167,7 +168,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_create_end_user_service_api_type(self, db_session_with_containers, factory): + def test_create_end_user_service_api_type(self, db_session_with_containers: Session, factory): """Test creating new end user with SERVICE_API type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -189,7 +190,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.app_id == app_id assert result.session_id == user_id - def test_create_end_user_web_app_type(self, db_session_with_containers, factory): + def test_create_end_user_web_app_type(self, db_session_with_containers: Session, factory): """Test creating new end user with WEB_APP type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -209,7 +210,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.type == InvokeFrom.WEB_APP @patch("services.end_user_service.logger") - def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers, factory): + def test_upgrade_legacy_end_user_type(self, mock_logger, db_session_with_containers: Session, factory): """Test upgrading legacy end user with different type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -243,7 +244,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert "Upgrading legacy EndUser" in log_call @patch("services.end_user_service.logger") - def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers, factory): + def test_get_existing_end_user_matching_type(self, mock_logger, db_session_with_containers: Session, factory): """Test retrieving existing end user with matching type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -272,7 +273,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.type == InvokeFrom.SERVICE_API mock_logger.info.assert_not_called() - def test_create_anonymous_user_with_default_session(self, db_session_with_containers, factory): + def test_create_anonymous_user_with_default_session(self, db_session_with_containers: Session, factory): """Test creating anonymous user when user_id is None.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -293,7 +294,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result._is_anonymous is True assert result.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID - def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers, factory): + def test_query_ordering_prioritizes_matching_type(self, db_session_with_containers: Session, factory): """Test that query ordering prioritizes records with matching type.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -328,7 +329,7 @@ class TestEndUserServiceGetOrCreateEndUserByType: assert result.id == matching.id assert result.id != non_matching.id - def test_external_user_id_matches_session_id(self, db_session_with_containers, factory): + def test_external_user_id_matches_session_id(self, db_session_with_containers: Session, factory): """Test that external_user_id is set to match session_id.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -357,7 +358,9 @@ class TestEndUserServiceGetOrCreateEndUserByType: InvokeFrom.DEBUGGER, ], ) - def test_create_end_user_with_different_invoke_types(self, db_session_with_containers, invoke_type, factory): + def test_create_end_user_with_different_invoke_types( + self, db_session_with_containers: Session, invoke_type, factory + ): """Test creating end users with different InvokeFrom types.""" # Arrange app = factory.create_app_and_account(db_session_with_containers) @@ -385,7 +388,7 @@ class TestEndUserServiceGetEndUserById: """Provide test data factory.""" return TestEndUserServiceFactory() - def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers, factory): + def test_get_end_user_by_id_returns_end_user(self, db_session_with_containers: Session, factory): app = factory.create_app_and_account(db_session_with_containers) existing_user = factory.create_end_user( db_session_with_containers, @@ -404,7 +407,7 @@ class TestEndUserServiceGetEndUserById: assert result is not None assert result.id == existing_user.id - def test_get_end_user_by_id_returns_none(self, db_session_with_containers, factory): + def test_get_end_user_by_id_returns_none(self, db_session_with_containers: Session, factory): app = factory.create_app_and_account(db_session_with_containers) result = EndUserService.get_end_user_by_id( @@ -423,7 +426,7 @@ class TestEndUserServiceCreateBatch: def factory(self): return TestEndUserServiceFactory() - def _create_multiple_apps(self, db_session_with_containers, factory, count: int = 3): + def _create_multiple_apps(self, db_session_with_containers: Session, factory, count: int = 3): """Create multiple apps under the same tenant.""" first_app = factory.create_app_and_account(db_session_with_containers) tenant_id = first_app.tenant_id @@ -452,13 +455,13 @@ class TestEndUserServiceCreateBatch: all_apps = db_session_with_containers.query(App).filter(App.tenant_id == tenant_id).all() return tenant_id, all_apps - def test_create_batch_empty_app_ids(self, db_session_with_containers): + def test_create_batch_empty_app_ids(self, db_session_with_containers: Session): result = EndUserService.create_end_user_batch( type=InvokeFrom.SERVICE_API, tenant_id=str(uuid4()), app_ids=[], user_id="user-1" ) assert result == {} - def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers, factory): + def test_create_batch_creates_users_for_all_apps(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) app_ids = [a.id for a in apps] user_id = f"user-{uuid4()}" @@ -473,7 +476,7 @@ class TestEndUserServiceCreateBatch: assert result[app_id].session_id == user_id assert result[app_id].type == InvokeFrom.SERVICE_API - def test_create_batch_default_session_id(self, db_session_with_containers, factory): + def test_create_batch_default_session_id(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [a.id for a in apps] @@ -486,7 +489,7 @@ class TestEndUserServiceCreateBatch: assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID assert end_user._is_anonymous is True - def test_create_batch_deduplicate_app_ids(self, db_session_with_containers, factory): + def test_create_batch_deduplicate_app_ids(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [apps[0].id, apps[1].id, apps[0].id, apps[1].id] user_id = f"user-{uuid4()}" @@ -497,7 +500,7 @@ class TestEndUserServiceCreateBatch: assert len(result) == 2 - def test_create_batch_returns_existing_users(self, db_session_with_containers, factory): + def test_create_batch_returns_existing_users(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=2) app_ids = [a.id for a in apps] user_id = f"user-{uuid4()}" @@ -516,7 +519,7 @@ class TestEndUserServiceCreateBatch: for app_id in app_ids: assert first_result[app_id].id == second_result[app_id].id - def test_create_batch_partial_existing_users(self, db_session_with_containers, factory): + def test_create_batch_partial_existing_users(self, db_session_with_containers: Session, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=3) user_id = f"user-{uuid4()}" @@ -545,7 +548,7 @@ class TestEndUserServiceCreateBatch: "invoke_type", [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER], ) - def test_create_batch_all_invoke_types(self, db_session_with_containers, invoke_type, factory): + def test_create_batch_all_invoke_types(self, db_session_with_containers: Session, invoke_type, factory): tenant_id, apps = self._create_multiple_apps(db_session_with_containers, factory, count=1) user_id = f"user-{uuid4()}" diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index 315936d721..f78aeaf984 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan from services.feature_service import ( @@ -81,7 +82,7 @@ class TestFeatureService: fake = Faker() return fake.uuid4() - def test_get_features_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful feature retrieval with billing and enterprise enabled. @@ -156,7 +157,7 @@ class TestFeatureService: tenant_id ) - def test_get_features_sandbox_plan(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_sandbox_plan(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test feature retrieval for sandbox plan with specific limitations. @@ -222,7 +223,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) - def test_get_knowledge_rate_limit_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_knowledge_rate_limit_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful knowledge rate limit retrieval with billing enabled. @@ -255,7 +258,7 @@ class TestFeatureService: tenant_id ) - def test_get_system_features_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system features retrieval with enterprise and marketplace enabled. @@ -332,7 +335,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_system_features_unauthenticated(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_unauthenticated( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval for an unauthenticated user. @@ -386,7 +391,9 @@ class TestFeatureService: # Marketplace should be visible assert result.enable_marketplace is True - def test_get_system_features_basic_config(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_basic_config( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with basic configuration (no enterprise). @@ -436,7 +443,9 @@ class TestFeatureService: # Verify plugin package size (uses default value from dify_config) assert result.max_plugin_package_size == 15728640 - def test_get_features_billing_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_billing_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval when billing is disabled. @@ -492,7 +501,7 @@ class TestFeatureService: assert result.webapp_copyright_enabled is False def test_get_knowledge_rate_limit_billing_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test knowledge rate limit retrieval when billing is disabled. @@ -523,7 +532,9 @@ class TestFeatureService: # Verify no billing service calls mock_external_service_dependencies["billing_service"].get_knowledge_rate_limit.assert_not_called() - def test_get_features_enterprise_only(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_enterprise_only( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with enterprise enabled but billing disabled. @@ -583,7 +594,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_not_called() def test_get_system_features_enterprise_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval when enterprise is disabled. @@ -640,7 +651,7 @@ class TestFeatureService: # Verify no enterprise service calls mock_external_service_dependencies["enterprise_service"].get_info.assert_not_called() - def test_get_features_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test feature retrieval without tenant ID (billing disabled). @@ -686,7 +697,9 @@ class TestFeatureService: # Verify no billing service calls mock_external_service_dependencies["billing_service"].get_info.assert_not_called() - def test_get_features_partial_billing_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_partial_billing_info( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with partial billing information. @@ -746,7 +759,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) - def test_get_features_edge_case_vector_space(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_vector_space( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case vector space configuration. @@ -807,7 +822,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_webapp_auth( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case webapp auth configuration. @@ -863,7 +878,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_members_quota(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_members_quota( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case members quota configuration. @@ -924,7 +941,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_plugin_installation_permission_scopes( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with different plugin installation permission scopes. @@ -1023,7 +1040,7 @@ class TestFeatureService: assert result.plugin_installation_permission.restrict_to_marketplace_only is True def test_get_features_workspace_members_missing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval when workspace members info is missing from enterprise. @@ -1064,7 +1081,9 @@ class TestFeatureService: tenant_id ) - def test_get_system_features_license_inactive(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_license_inactive( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with inactive license. @@ -1117,7 +1136,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_system_features_partial_enterprise_info( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with partial enterprise information. @@ -1186,7 +1205,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_limits(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_limits( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case limit values. @@ -1244,7 +1265,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_protocols( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case protocol values. @@ -1297,7 +1318,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() - def test_get_features_edge_case_education(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_features_edge_case_education( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test feature retrieval with edge case education configuration. @@ -1353,7 +1376,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_license_limitation_model_is_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test LicenseLimitationModel.is_available method with various scenarios. @@ -1394,7 +1417,7 @@ class TestFeatureService: assert exact_limit.is_available(3) is True def test_get_features_workspace_members_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval when workspace members are disabled in enterprise. @@ -1433,7 +1456,9 @@ class TestFeatureService: # Verify mock interactions mock_external_service_dependencies["enterprise_service"].get_workspace_info.assert_called_once_with(tenant_id) - def test_get_system_features_license_expired(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_system_features_license_expired( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test system features retrieval with expired license. @@ -1486,7 +1511,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_docs_processing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case document processing configuration. @@ -1544,7 +1569,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_branding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features retrieval with edge case branding configuration. @@ -1606,7 +1631,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_annotation_quota( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case annotation quota configuration. @@ -1668,7 +1693,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_features_edge_case_documents_upload( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with edge case documents upload settings. @@ -1733,7 +1758,7 @@ class TestFeatureService: mock_external_service_dependencies["billing_service"].get_info.assert_called_once_with(tenant_id) def test_get_system_features_edge_case_license_lost( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test system features with lost license status. @@ -1784,7 +1809,7 @@ class TestFeatureService: mock_external_service_dependencies["enterprise_service"].get_info.assert_called_once() def test_get_features_edge_case_education_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test feature retrieval with education feature disabled. diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index ed75363f3b..ce63e7a71a 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -6,6 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session from configs import dify_config from core.workflow.human_input_adapter import ( @@ -88,7 +89,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, flask_app_with_containers, db_session_with_containers): + def test_default(self, flask_app_with_containers, db_session_with_containers: Session): registry = DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -260,7 +261,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account = Account(name="Test User", email="member@example.com") db_session_with_containers.add(account) @@ -282,7 +283,7 @@ class TestEmailDeliveryTestHandler: ) assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers: Session): tenant_id = str(uuid4()) account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py index b55a19eaa9..fffa82bf5c 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -5,6 +5,7 @@ from uuid import uuid4 import pytest from sqlalchemy import select +from sqlalchemy.orm import Session from models.dataset import Dataset, DatasetMetadataBinding, Document from models.enums import DataSourceType, DocumentCreatedFrom @@ -65,7 +66,7 @@ class TestMetadataPartialUpdate: yield account def test_partial_update_merges_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -92,7 +93,7 @@ class TestMetadataPartialUpdate: assert updated_doc.doc_metadata["new_key"] == "new_value" def test_full_update_replaces_metadata( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -119,7 +120,7 @@ class TestMetadataPartialUpdate: assert "existing_key" not in updated_doc.doc_metadata def test_partial_update_skips_existing_binding( - self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, user_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( @@ -159,7 +160,7 @@ class TestMetadataPartialUpdate: assert len(bindings) == 1 def test_rollback_called_on_commit_failure( - self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + self, flask_app_with_containers, db_session_with_containers: Session, tenant_id, mock_current_account ): dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) document = _create_document( diff --git a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py index c146a5924b..5fa5de6d80 100644 --- a/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py +++ b/api/tests/test_containers_integration_tests/services/test_oauth_server_service.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from models.model import OAuthProviderApp @@ -25,7 +26,7 @@ from services.oauth_server import ( class TestOAuthServerServiceGetProviderApp: """DB-backed tests for get_oauth_provider_app.""" - def _create_oauth_provider_app(self, db_session_with_containers, *, client_id: str) -> OAuthProviderApp: + def _create_oauth_provider_app(self, db_session_with_containers: Session, *, client_id: str) -> OAuthProviderApp: app = OAuthProviderApp( app_icon="icon.png", client_id=client_id, @@ -38,7 +39,7 @@ class TestOAuthServerServiceGetProviderApp: db_session_with_containers.commit() return app - def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_app_when_exists(self, db_session_with_containers: Session): client_id = f"client-{uuid4()}" created = self._create_oauth_provider_app(db_session_with_containers, client_id=client_id) @@ -48,7 +49,7 @@ class TestOAuthServerServiceGetProviderApp: assert result.client_id == client_id assert result.id == created.id - def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers): + def test_get_oauth_provider_app_returns_none_when_not_exists(self, db_session_with_containers: Session): result = OAuthServerService.get_oauth_provider_app(f"nonexistent-{uuid4()}") assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index 7036524918..2f20949611 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -8,6 +8,7 @@ from datetime import datetime from uuid import uuid4 from sqlalchemy import select +from sqlalchemy.orm import Session from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore @@ -39,7 +40,7 @@ class TestWorkflowRunRestore: assert result["created_at"].month == 1 assert result["name"] == "test" - def test_restore_table_records_returns_rowcount(self, db_session_with_containers): + def test_restore_table_records_returns_rowcount(self, db_session_with_containers: Session): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() record_id = str(uuid4()) @@ -65,7 +66,7 @@ class TestWorkflowRunRestore: restored_pause = db_session_with_containers.scalar(select(WorkflowPause).where(WorkflowPause.id == record_id)) assert restored_pause is not None - def test_restore_table_records_unknown_table(self, db_session_with_containers): + def test_restore_table_records_unknown_table(self, db_session_with_containers: Session): """Unknown table names should be ignored gracefully.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index 5a6bf0466e..583b6128e6 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -1099,38 +1099,39 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - # Create tag - tag = self._create_test_tags( - db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 1 - )[0] + # Create tags + tags = self._create_test_tags( + db_session_with_containers, mock_external_service_dependencies, tenant.id, "knowledge", 2 + ) - # Create dataset and bind tag + # Create dataset and bind tags dataset = self._create_test_dataset(db_session_with_containers, mock_external_service_dependencies, tenant.id) self._create_test_tag_bindings( - db_session_with_containers, mock_external_service_dependencies, [tag], dataset.id, tenant.id + db_session_with_containers, mock_external_service_dependencies, tags, dataset.id, tenant.id ) - # Verify binding exists before deletion - - binding_before = ( + # Verify bindings exist before deletion + bindings_before = ( db_session_with_containers.query(TagBinding) - .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) - .first() + .where(TagBinding.tag_id.in_([tag.id for tag in tags]), TagBinding.target_id == dataset.id) + .all() ) - assert binding_before is not None + assert len(bindings_before) == 2 # Act: Execute the method under test - delete_payload = TagBindingDeletePayload(type="knowledge", target_id=dataset.id, tag_id=tag.id) + delete_payload = TagBindingDeletePayload( + type="knowledge", target_id=dataset.id, tag_ids=[tag.id for tag in tags] + ) TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes - # Verify tag binding was deleted - binding_after = ( + # Verify tag bindings were deleted + bindings_after = ( db_session_with_containers.query(TagBinding) - .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) - .first() + .where(TagBinding.tag_id.in_([tag.id for tag in tags]), TagBinding.target_id == dataset.id) + .all() ) - assert binding_after is None + assert len(bindings_after) == 0 def test_delete_tag_binding_non_existent_binding( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -1156,7 +1157,7 @@ class TestTagService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, tenant.id) # Act: Try to delete non-existent binding - delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_id=tag.id) + delete_payload = TagBindingDeletePayload(type="app", target_id=app.id, tag_ids=[tag.id]) TagService.delete_tag_binding(delete_payload) # Assert: Verify the expected outcomes diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 970da98c55..6d5c7380b7 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from flask import Flask +from sqlalchemy.orm import Session from werkzeug.datastructures import FileStorage from models.enums import AppTriggerStatus, AppTriggerType @@ -52,7 +53,7 @@ class TestWebhookService: } @pytest.fixture - def test_data(self, db_session_with_containers, mock_external_dependencies): + def test_data(self, db_session_with_containers: Session, mock_external_dependencies): """Create test data for webhook service tests.""" fake = Faker() @@ -160,7 +161,7 @@ class TestWebhookService: "app_trigger": app_trigger, } - def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_success(self, test_data, flask_app_with_containers: Flask): """Test successful retrieval of webhook trigger and workflow.""" webhook_id = test_data["webhook_id"] @@ -175,7 +176,7 @@ class TestWebhookService: assert node_config["id"] == "webhook_node" assert node_config["data"].title == "Test Webhook" - def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): + def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers: Flask): """Test webhook trigger not found scenario.""" with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Webhook not found"): @@ -421,7 +422,9 @@ class TestWebhookService: assert result["files"] == {} - def test_trigger_workflow_execution_success(self, test_data, mock_external_dependencies, flask_app_with_containers): + def test_trigger_workflow_execution_success( + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask + ): """Test successful workflow execution trigger.""" webhook_data = { "method": "POST", @@ -452,7 +455,7 @@ class TestWebhookService: mock_external_dependencies["async_service"].trigger_workflow_async.assert_called_once() def test_trigger_workflow_execution_end_user_service_failure( - self, test_data, mock_external_dependencies, flask_app_with_containers + self, test_data, mock_external_dependencies, flask_app_with_containers: Flask ): """Test workflow execution trigger when EndUserService fails.""" webhook_data = {"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}} diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py index 85ce3a6ba6..69cde847f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service_relationships.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from flask import Flask from sqlalchemy import select from sqlalchemy.orm import Session @@ -165,7 +166,7 @@ class WebhookServiceRelationshipFactory: class TestWebhookServiceLookupWithContainers: def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_missing( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -182,7 +183,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_rate_limited( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -202,7 +203,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_app_trigger_disabled( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -222,7 +223,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_raises_when_workflow_missing( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -239,7 +240,7 @@ class TestWebhookServiceLookupWithContainers: WebhookService.get_webhook_trigger_and_workflow(webhook_trigger.webhook_id) def test_get_webhook_trigger_and_workflow_returns_debug_draft_workflow( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -275,7 +276,7 @@ class TestWebhookServiceLookupWithContainers: class TestWebhookServiceTriggerExecutionWithContainers: def test_trigger_workflow_execution_triggers_async_workflow_successfully( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -318,7 +319,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: assert trigger_args[2].root_node_id == webhook_trigger.node_id def test_trigger_workflow_execution_marks_tenant_rate_limited_when_quota_exceeded( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -354,7 +355,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: mock_mark_rate_limited.assert_called_once_with(tenant.id) def test_trigger_workflow_execution_logs_and_reraises_unexpected_errors( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -386,7 +387,7 @@ class TestWebhookServiceTriggerExecutionWithContainers: class TestWebhookServiceRelationshipSyncWithContainers: def test_sync_webhook_relationships_raises_when_workflow_exceeds_node_limit( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -401,7 +402,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: WebhookService.sync_webhook_relationships(app, workflow) def test_sync_webhook_relationships_raises_when_lock_not_acquired( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -418,7 +419,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: WebhookService.sync_webhook_relationships(app, workflow) def test_sync_webhook_relationships_creates_missing_records_and_deletes_stale_records( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -455,7 +456,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert db_session_with_containers.get(WorkflowWebhookTrigger, stale_trigger_id) is None def test_sync_webhook_relationships_sets_redis_cache_for_new_record( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory @@ -481,7 +482,7 @@ class TestWebhookServiceRelationshipSyncWithContainers: assert cached_payload["webhook_id"] == "cache-webhook-id-00001" def test_sync_webhook_relationships_logs_when_lock_release_fails( - self, db_session_with_containers: Session, flask_app_with_containers + self, db_session_with_containers: Session, flask_app_with_containers: Flask ): del flask_app_with_containers factory = WebhookServiceRelationshipFactory diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 1e57b5603d..a2cdddad61 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -1530,7 +1530,7 @@ class TestWorkflowAppService: assert result_cross_tenant["total"] == 0 def test_get_paginate_workflow_app_logs_raises_when_account_filter_email_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() @@ -1543,7 +1543,7 @@ class TestWorkflowAppService: ) def test_get_paginate_workflow_app_logs_filters_by_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() @@ -1558,7 +1558,9 @@ class TestWorkflowAppService: assert result["total"] >= 0 assert isinstance(result["data"], list) - def test_get_paginate_workflow_archive_logs(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_workflow_archive_logs( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) service = WorkflowAppService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 86cf2327c7..82fe391b08 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -45,7 +45,9 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, fake: Faker | None = None + ): """ Helper method to create a test app with realistic data for testing. @@ -80,7 +82,7 @@ class TestWorkflowDraftVariableService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index b5ce8a53de..9ba1fda08b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -12,7 +12,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from models import Account, App, Workflow +from models import Account, AccountStatus, App, TenantStatus, Workflow from models.model import AppMode from models.workflow import WorkflowType from services.workflow_service import WorkflowService @@ -33,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -49,7 +49,7 @@ class TestWorkflowService: email=fake.email(), name=fake.name(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", # Set interface language for Site creation ) account.created_at = fake.date_time_this_year() @@ -62,7 +62,7 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="normal", + status=TenantStatus.NORMAL, ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() @@ -77,7 +77,7 @@ class TestWorkflowService: return account - def _create_test_app(self, db_session_with_containers: Session, fake=None): + def _create_test_app(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test app with realistic data. @@ -109,7 +109,7 @@ class TestWorkflowService: db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake: Faker | None = None): """ Helper method to create a test workflow associated with an app. diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py index 29e1e240b4..afc4908c15 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -100,7 +100,7 @@ class TestWorkflowDeletion: session.flush() return provider - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -118,7 +118,7 @@ class TestWorkflowDeletion: db_session_with_containers.expire_all() assert db_session_with_containers.get(Workflow, workflow_id) is None - def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + def test_delete_draft_workflow_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -130,7 +130,7 @@ class TestWorkflowDeletion: with pytest.raises(DraftWorkflowDeletionError): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( @@ -144,7 +144,7 @@ class TestWorkflowDeletion: with pytest.raises(WorkflowInUseError, match="currently in use by app"): service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) - def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers: Session): tenant, account = self._create_tenant_and_account(db_session_with_containers) app = self._create_app(db_session_with_containers, tenant=tenant, account=account) workflow = self._create_workflow( diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index 4dab895135..32b76c3469 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -64,7 +64,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: db_session_with_containers.commit() return execution - def test_get_node_last_execution_found(self, db_session_with_containers): + def test_get_node_last_execution_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it exists.""" # Arrange tenant_id = str(uuid4()) @@ -110,7 +110,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result.id == expected.id assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - def test_get_node_last_execution_not_found(self, db_session_with_containers): + def test_get_node_last_execution_not_found(self, db_session_with_containers: Session): """Test getting the last execution for a node when it doesn't exist.""" # Arrange tenant_id = str(uuid4()) @@ -129,7 +129,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_get_executions_by_workflow_run_empty(self, db_session_with_containers): + def test_get_executions_by_workflow_run_empty(self, db_session_with_containers: Session): """Test getting executions for a workflow run when none exist.""" # Arrange tenant_id = str(uuid4()) @@ -147,7 +147,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result == [] - def test_get_execution_by_id_found(self, db_session_with_containers): + def test_get_execution_by_id_found(self, db_session_with_containers: Session): """Test getting execution by ID when it exists.""" # Arrange execution = self._create_execution( @@ -170,7 +170,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert result is not None assert result.id == execution.id - def test_get_execution_by_id_not_found(self, db_session_with_containers): + def test_get_execution_by_id_not_found(self, db_session_with_containers: Session): """Test getting execution by ID when it doesn't exist.""" # Arrange repository = self._create_repository(db_session_with_containers) @@ -182,7 +182,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: # Assert assert result is None - def test_delete_expired_executions(self, db_session_with_containers): + def test_delete_expired_executions(self, db_session_with_containers: Session): """Test deleting expired executions.""" # Arrange tenant_id = str(uuid4()) @@ -248,7 +248,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_2_id not in remaining_ids assert kept_execution_id in remaining_ids - def test_delete_executions_by_app(self, db_session_with_containers): + def test_delete_executions_by_app(self, db_session_with_containers: Session): """Test deleting executions by app.""" # Arrange tenant_id = str(uuid4()) @@ -313,7 +313,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert deleted_2_id not in remaining_ids assert kept_id in remaining_ids - def test_get_expired_executions_batch(self, db_session_with_containers): + def test_get_expired_executions_batch(self, db_session_with_containers: Session): """Test getting expired executions batch for backup.""" # Arrange tenant_id = str(uuid4()) @@ -370,7 +370,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: assert old_execution_1.id in result_ids assert old_execution_2.id in result_ids - def test_delete_executions_by_ids(self, db_session_with_containers): + def test_delete_executions_by_ids(self, db_session_with_containers: Session): """Test deleting executions by IDs.""" # Arrange tenant_id = str(uuid4()) @@ -424,7 +424,7 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: ).all() assert remaining == [] - def test_delete_executions_by_ids_empty_list(self, db_session_with_containers): + def test_delete_executions_by_ids_empty_list(self, db_session_with_containers: Session): """Test deleting executions with empty ID list.""" # Arrange repository = self._create_repository(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 7e5c374b5d..1c8d5969e0 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -71,7 +71,7 @@ class TestCleanNotionDocumentTask: yield mock_factory def test_clean_notion_document_task_success( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test successful cleanup of Notion documents with proper database operations. @@ -176,7 +176,7 @@ class TestCleanNotionDocumentTask: # 5. The task completes without errors def test_clean_notion_document_task_dataset_not_found( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior when dataset is not found. @@ -196,7 +196,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.return_value.init_index_processor.assert_not_called() def test_clean_notion_document_task_empty_document_list( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task behavior with empty document list. @@ -240,7 +240,7 @@ class TestCleanNotionDocumentTask: assert args[1] == [] def test_clean_notion_document_task_with_different_index_types( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with different dataset index types. @@ -328,7 +328,7 @@ class TestCleanNotionDocumentTask: mock_index_processor_factory.reset_mock() def test_clean_notion_document_task_with_segments_no_index_node_ids( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments that have no index_node_ids. @@ -411,7 +411,7 @@ class TestCleanNotionDocumentTask: # are properly deleted from the database. def test_clean_notion_document_task_partial_document_cleanup( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with partial document cleanup scenario. @@ -513,7 +513,7 @@ class TestCleanNotionDocumentTask: # The database operations work correctly, isolating only the specified documents. def test_clean_notion_document_task_with_mixed_segment_statuses( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with segments in different statuses. @@ -603,7 +603,7 @@ class TestCleanNotionDocumentTask: # IndexProcessor verification would require more sophisticated mocking. def test_clean_notion_document_task_continues_when_index_processor_fails( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Index processor failure (e.g. transient billing API error propagated via @@ -707,7 +707,7 @@ class TestCleanNotionDocumentTask: assert _count_segments(db_session_with_containers, DocumentSegment.document_id == document.id) == 0 def test_clean_notion_document_task_with_large_number_of_documents( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with a large number of documents and segments. @@ -806,7 +806,7 @@ class TestCleanNotionDocumentTask: # The database efficiently handles large-scale deletions. def test_clean_notion_document_task_with_documents_from_different_tenants( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents from different tenants. @@ -918,7 +918,7 @@ class TestCleanNotionDocumentTask: # Only documents from the target dataset are affected, maintaining tenant separation. def test_clean_notion_document_task_with_documents_in_different_states( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents in different indexing states. @@ -1024,7 +1024,7 @@ class TestCleanNotionDocumentTask: # All documents are deleted regardless of their indexing status. def test_clean_notion_document_task_with_documents_having_metadata( - self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_index_processor_factory, mock_external_service_dependencies ): """ Test cleanup task with documents that have rich metadata. diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 9084667c31..80289c448a 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -12,6 +12,7 @@ from uuid import uuid4 import pytest from faker import Faker from sqlalchemy import delete +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from extensions.ext_redis import redis_client @@ -25,7 +26,7 @@ class TestCreateSegmentToIndexTask: """Integration tests for create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database and Redis before each test to ensure isolation.""" # Clear all test data using fixture session @@ -55,7 +56,7 @@ class TestCreateSegmentToIndexTask: "index_processor": mock_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -102,7 +103,7 @@ class TestCreateSegmentToIndexTask: return account, tenant - def _create_test_dataset_and_document(self, db_session_with_containers, tenant_id, account_id): + def _create_test_dataset_and_document(self, db_session_with_containers: Session, tenant_id, account_id): """ Helper method to create a test dataset and document for testing. @@ -151,7 +152,13 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING + self, + db_session_with_containers: Session, + dataset_id, + document_id, + tenant_id, + account_id, + status=SegmentStatus.WAITING, ): """ Helper method to create a test document segment for testing. @@ -189,7 +196,9 @@ class TestCreateSegmentToIndexTask: return segment - def test_create_segment_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful creation of segment to index. @@ -225,7 +234,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_segment_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent segment ID. @@ -246,7 +255,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_invalid_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with invalid status. @@ -277,7 +286,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_dataset(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated dataset. @@ -330,7 +341,9 @@ class TestCreateSegmentToIndexTask: # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() - def test_create_segment_to_index_no_document(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_segment_to_index_no_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test handling of segment without associated document. @@ -367,7 +380,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with disabled document. @@ -403,7 +416,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_archived( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with archived document. @@ -439,7 +452,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_document_indexing_incomplete( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of segment with document that has incomplete indexing. @@ -475,7 +488,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_not_called() def test_create_segment_to_index_processor_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of index processor exceptions. @@ -511,7 +524,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 0 def test_create_segment_to_index_with_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with custom keywords. @@ -543,7 +556,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with different document forms. @@ -586,7 +599,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) def test_create_segment_to_index_performance_timing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing performance and timing. @@ -617,7 +630,7 @@ class TestCreateSegmentToIndexTask: assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test concurrent execution of segment indexing tasks. @@ -654,7 +667,7 @@ class TestCreateSegmentToIndexTask: assert mock_external_service_dependencies["index_processor_factory"].call_count == 3 def test_create_segment_to_index_large_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with large content. @@ -703,7 +716,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_redis_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing when Redis operations fail. @@ -743,7 +756,7 @@ class TestCreateSegmentToIndexTask: assert redis_client.exists(cache_key) == 1 def test_create_segment_to_index_database_transaction_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with database transaction handling. @@ -775,7 +788,7 @@ class TestCreateSegmentToIndexTask: assert segment.error is not None def test_create_segment_to_index_metadata_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with metadata validation. @@ -817,7 +830,7 @@ class TestCreateSegmentToIndexTask: assert doc is not None def test_create_segment_to_index_status_transition_flow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test complete status transition flow during indexing. @@ -852,7 +865,7 @@ class TestCreateSegmentToIndexTask: assert segment.indexing_at <= segment.completed_at def test_create_segment_to_index_with_empty_content( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with empty or minimal content. @@ -894,7 +907,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with special characters and unicode content. @@ -940,7 +953,7 @@ class TestCreateSegmentToIndexTask: assert segment.completed_at is not None def test_create_segment_to_index_with_long_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with long keyword lists. @@ -974,7 +987,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with proper tenant isolation. @@ -1017,7 +1030,7 @@ class TestCreateSegmentToIndexTask: assert segment1.tenant_id != segment2.tenant_id def test_create_segment_to_index_with_none_keywords( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment indexing with None keywords parameter. @@ -1048,7 +1061,7 @@ class TestCreateSegmentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() def test_create_segment_to_index_comprehensive_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Comprehensive integration test covering multiple scenarios. diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 684097851b..a5a3cd10b5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -174,11 +175,11 @@ class TestDatasetIndexingTaskIntegration: return dataset, documents - def _query_document(self, db_session_with_containers, document_id: str) -> Document | None: + def _query_document(self, db_session_with_containers: Session, document_id: str) -> Document | None: """Return the latest persisted document state.""" return db_session_with_containers.scalar(select(Document).where(Document.id == document_id).limit(1)) - def _assert_documents_parsing(self, db_session_with_containers, document_ids: Sequence[str]) -> None: + def _assert_documents_parsing(self, db_session_with_containers: Session, document_ids: Sequence[str]) -> None: """Assert all target documents are persisted in parsing status.""" db_session_with_containers.expire_all() for document_id in document_ids: @@ -212,7 +213,9 @@ class TestDatasetIndexingTaskIntegration: assert len(opened) >= 2 assert opened_ids <= closed_ids - def test_legacy_document_indexing_task_still_works(self, db_session_with_containers, patched_external_dependencies): + def test_legacy_document_indexing_task_still_works( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Ensure the legacy task entrypoint still updates parsing status.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -225,7 +228,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_multiple_documents(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_multiple_documents( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Process multiple documents in one batch.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -240,7 +245,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == len(document_ids) self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_batch_processing_with_limit_check(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_with_limit_check( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Reject batches larger than configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. @@ -263,7 +270,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_error_contains(db_session_with_containers, document_ids, "batch upload limit") def test_batch_processing_sandbox_plan_single_document_only( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Reject multi-document upload under sandbox plan.""" # Arrange @@ -280,7 +287,9 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_not_called() self._assert_documents_error_contains(db_session_with_containers, document_ids, "does not support batch upload") - def test_batch_processing_empty_document_list(self, db_session_with_containers, patched_external_dependencies): + def test_batch_processing_empty_document_list( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Handle empty list input without failing.""" # Arrange dataset, _ = self._create_test_dataset_and_documents(db_session_with_containers, document_count=0) @@ -292,7 +301,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_tenant_queue_dispatches_next_task_after_completion( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Dispatch the next queued task after current tenant task completes. @@ -337,7 +346,7 @@ class TestDatasetIndexingTaskIntegration: delete_key_spy.assert_not_called() def test_tenant_queue_deletes_running_key_when_no_follow_up_tasks( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Delete tenant running flag when queue has no pending tasks. @@ -362,7 +371,7 @@ class TestDatasetIndexingTaskIntegration: delete_key_spy.assert_called_once() def test_validation_failure_sets_error_status_when_vector_space_at_limit( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Set error status when vector space validation fails before runner phase.""" # Arrange @@ -382,7 +391,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_error_contains(db_session_with_containers, document_ids, "over the limit") def test_runner_exception_does_not_crash_indexing_task( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Catch generic runner exceptions without crashing the task.""" # Arrange @@ -397,7 +406,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_document_paused_error_handling(self, db_session_with_containers, patched_external_dependencies): + def test_document_paused_error_handling(self, db_session_with_containers: Session, patched_external_dependencies): """Handle DocumentIsPausedError and keep persisted state consistent.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=2) @@ -424,7 +433,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_not_called() def test_tenant_queue_error_handling_still_processes_next_task( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Even on current task failure, enqueue the next waiting tenant task. @@ -491,7 +500,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_all_opened_sessions_closed(session_close_tracker) def test_multiple_documents_with_mixed_success_and_failure( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Process only existing documents when request includes missing ids.""" # Arrange @@ -508,7 +517,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, existing_ids) def test_tenant_queue_dispatches_up_to_concurrency_limit( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Dispatch only up to configured concurrency under queued backlog burst. @@ -543,7 +552,7 @@ class TestDatasetIndexingTaskIntegration: assert task_dispatch_spy.apply_async.call_count == concurrency_limit assert set_waiting_spy.call_count == concurrency_limit - def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies): + def test_task_queue_fifo_ordering(self, db_session_with_containers: Session, patched_external_dependencies): """Keep FIFO ordering when dispatching next queued tasks. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -576,7 +585,9 @@ class TestDatasetIndexingTaskIntegration: call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {}) assert call_kwargs.get("document_ids") == expected_task["document_ids"] - def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies): + def test_billing_disabled_skips_limit_checks( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Skip limit checks when billing feature is disabled.""" # Arrange large_document_ids = [str(uuid.uuid4()) for _ in range(100)] @@ -595,7 +606,7 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == 100 self._assert_documents_parsing(db_session_with_containers, large_document_ids) - def test_complete_workflow_normal_task(self, db_session_with_containers, patched_external_dependencies): + def test_complete_workflow_normal_task(self, db_session_with_containers: Session, patched_external_dependencies): """Run end-to-end normal queue workflow with tenant queue cleanup. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -618,7 +629,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) delete_key_spy.assert_called_once() - def test_complete_workflow_priority_task(self, db_session_with_containers, patched_external_dependencies): + def test_complete_workflow_priority_task(self, db_session_with_containers: Session, patched_external_dependencies): """Run end-to-end priority queue workflow with tenant queue cleanup. Queue APIs are patched to isolate dispatch side effects while preserving DB assertions. @@ -641,7 +652,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) delete_key_spy.assert_called_once() - def test_single_document_processing(self, db_session_with_containers, patched_external_dependencies): + def test_single_document_processing(self, db_session_with_containers: Session, patched_external_dependencies): """Process the minimum batch size (single document).""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=1) @@ -655,7 +666,9 @@ class TestDatasetIndexingTaskIntegration: assert len(run_args) == 1 self._assert_documents_parsing(db_session_with_containers, [document_id]) - def test_document_with_special_characters_in_id(self, db_session_with_containers, patched_external_dependencies): + def test_document_with_special_characters_in_id( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Handle standard UUID ids with hyphen characters safely.""" # Arrange special_document_id = str(uuid.uuid4()) @@ -670,7 +683,9 @@ class TestDatasetIndexingTaskIntegration: # Assert self._assert_documents_parsing(db_session_with_containers, [special_document_id]) - def test_zero_vector_space_limit_allows_unlimited(self, db_session_with_containers, patched_external_dependencies): + def test_zero_vector_space_limit_allows_unlimited( + self, db_session_with_containers: Session, patched_external_dependencies + ): """Treat vector limit 0 as unlimited and continue indexing.""" # Arrange dataset, documents = self._create_test_dataset_and_documents(db_session_with_containers, document_count=3) @@ -689,7 +704,7 @@ class TestDatasetIndexingTaskIntegration: self._assert_documents_parsing(db_session_with_containers, document_ids) def test_negative_vector_space_values_handled_gracefully( - self, db_session_with_containers, patched_external_dependencies + self, db_session_with_containers: Session, patched_external_dependencies ): """Treat negative vector limits as non-blocking and continue indexing.""" # Arrange @@ -708,7 +723,7 @@ class TestDatasetIndexingTaskIntegration: patched_external_dependencies["indexing_runner_instance"].run.assert_called_once() self._assert_documents_parsing(db_session_with_containers, document_ids) - def test_large_document_batch_processing(self, db_session_with_containers, patched_external_dependencies): + def test_large_document_batch_processing(self, db_session_with_containers: Session, patched_external_dependencies): """Process a batch exactly at configured upload limit. This test patches config only to force a deterministic limit branch while keeping SQL writes real. diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 48fec441c5..e4cbb9e589 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -12,6 +12,7 @@ from unittest.mock import ANY, Mock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType from models.dataset import Dataset, Document, DocumentSegment @@ -55,7 +56,7 @@ class TestDealDatasetVectorIndexTask: yield mock_factory @pytest.fixture - def account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """Create an account with an owner tenant for testing. Returns a tuple of (account, tenant) where tenant is guaranteed to be non-None. @@ -73,7 +74,7 @@ class TestDealDatasetVectorIndexTask: return account, tenant def test_deal_dataset_vector_index_task_remove_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful removal of dataset vector index. @@ -131,7 +132,7 @@ class TestDealDatasetVectorIndexTask: assert mock_processor.clean.call_count >= 0 # For now, just check it doesn't fail def test_deal_dataset_vector_index_task_add_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful addition of dataset vector index. @@ -233,7 +234,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_update_action_success( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test successful update of dataset vector index. @@ -337,7 +338,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_dataset_not_found_error( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior when dataset is not found. @@ -357,7 +358,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_no_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action when no documents exist for the dataset. @@ -389,7 +390,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_no_segments( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action when documents exist but have no segments. @@ -447,7 +448,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_update_action_no_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test update action when no documents exist for the dataset. @@ -480,7 +481,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_not_called() def test_deal_dataset_vector_index_task_add_action_with_exception_handling( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test add action with exception handling during processing. @@ -578,7 +579,7 @@ class TestDealDatasetVectorIndexTask: assert "Test exception during indexing" in updated_document.error def test_deal_dataset_vector_index_task_with_custom_index_type( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with custom index type (QA_INDEX). @@ -656,7 +657,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_default_index_type( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with default index type (PARAGRAPH_INDEX). @@ -734,7 +735,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_multiple_documents_processing( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task processing with multiple documents and segments. @@ -839,7 +840,7 @@ class TestDealDatasetVectorIndexTask: assert mock_processor.load.call_count == 3 def test_deal_dataset_vector_index_task_document_status_transitions( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test document status transitions during task execution. @@ -938,7 +939,7 @@ class TestDealDatasetVectorIndexTask: assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with disabled documents. @@ -1061,7 +1062,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_archived_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with archived documents. @@ -1184,7 +1185,7 @@ class TestDealDatasetVectorIndexTask: mock_processor.load.assert_called_once() def test_deal_dataset_vector_index_task_with_incomplete_documents( - self, db_session_with_containers, mock_index_processor_factory, account_and_tenant + self, db_session_with_containers: Session, mock_index_processor_factory, account_and_tenant ): """ Test task behavior with documents that have incomplete indexing status. diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index 8a69707b38..f4a71040c1 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -11,9 +11,19 @@ import logging from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Dataset, Document, DocumentSegment, Tenant +from models import ( + Account, + AccountStatus, + Dataset, + DatasetPermissionEnum, + Document, + DocumentSegment, + Tenant, + TenantStatus, +) from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task @@ -37,7 +47,7 @@ class TestDeleteSegmentFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_tenant(self, db_session_with_containers, fake=None): + def _create_test_tenant(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant with realistic data. @@ -49,7 +59,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status=TenantStatus.NORMAL) tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -58,7 +68,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return tenant - def _create_test_account(self, db_session_with_containers, tenant, fake=None): + def _create_test_account(self, db_session_with_containers: Session, tenant, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -75,7 +85,7 @@ class TestDeleteSegmentFromIndexTask: name=fake.name(), email=fake.email(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", ) account.id = fake.uuid4() @@ -86,7 +96,9 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return account - def _create_test_dataset(self, db_session_with_containers, tenant, account, fake=None): + def _create_test_dataset( + self, db_session_with_containers: Session, tenant: Tenant, account: Account, fake: Faker | None = None + ): """ Helper method to create a test dataset with realistic data. @@ -106,7 +118,7 @@ class TestDeleteSegmentFromIndexTask: dataset.name = f"Test Dataset {fake.word()}" dataset.description = fake.text(max_nb_chars=200) dataset.provider = "vendor" - dataset.permission = "only_me" + dataset.permission = DatasetPermissionEnum.ONLY_ME dataset.data_source_type = DataSourceType.UPLOAD_FILE dataset.indexing_technique = IndexTechniqueType.HIGH_QUALITY dataset.index_struct = '{"type": "paragraph"}' @@ -122,7 +134,7 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None, **kwargs): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None, **kwargs): """ Helper method to create a test document with realistic data. @@ -172,7 +184,14 @@ class TestDeleteSegmentFromIndexTask: db_session_with_containers.commit() return document - def _create_test_document_segments(self, db_session_with_containers, document, account, count=3, fake=None): + def _create_test_document_segments( + self, + db_session_with_containers: Session, + document: Document, + account: Account, + count: int = 3, + fake: Faker | None = None, + ): """ Helper method to create test document segments with realistic data. @@ -218,7 +237,9 @@ class TestDeleteSegmentFromIndexTask: return segments @patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True) - def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers): + def test_delete_segment_from_index_task_success( + self, mock_index_processor_factory, db_session_with_containers: Session + ): """ Test successful segment deletion from index with comprehensive verification. @@ -267,7 +288,7 @@ class TestDeleteSegmentFromIndexTask: assert call_args[1]["with_keywords"] is True assert call_args[1]["delete_child_chunks"] is True - def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers): + def test_delete_segment_from_index_task_dataset_not_found(self, db_session_with_containers: Session): """ Test task behavior when dataset is not found. @@ -288,7 +309,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when dataset not found - def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_not_found(self, db_session_with_containers: Session): """ Test task behavior when document is not found. @@ -314,7 +335,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document not found - def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_disabled(self, db_session_with_containers: Session): """ Test task behavior when document is disabled. @@ -342,7 +363,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document is disabled - def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_archived(self, db_session_with_containers: Session): """ Test task behavior when document is archived. @@ -370,7 +391,7 @@ class TestDeleteSegmentFromIndexTask: # Verify the task completed without exceptions assert result is None # Task should return None when document is archived - def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers): + def test_delete_segment_from_index_task_document_not_completed(self, db_session_with_containers: Session): """ Test task behavior when document indexing is not completed. diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index 6e03bd9351..6bfb1e1f1e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -13,7 +13,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Dataset, DocumentSegment +from models import Account, AccountStatus, Dataset, DocumentSegment, TenantAccountRole, TenantStatus from models import Document as DatasetDocument from models.dataset import DatasetProcessRule from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus @@ -35,7 +35,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers: Session, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account with realistic data. @@ -51,24 +51,23 @@ class TestDisableSegmentsFromIndexTask: email=fake.email(), name=fake.name(), avatar=fake.url(), - status="active", + status=AccountStatus.ACTIVE, interface_language="en-US", ) - account.id = fake.uuid4() # monkey-patch attributes for test setup + account.updated_at = fake.date_time_this_year() + account.created_at = fake.date_time_this_year() + account.role = TenantAccountRole.OWNER + account.id = fake.uuid4() account.tenant_id = fake.uuid4() account.type = "normal" - account.role = "owner" - account.created_at = fake.date_time_this_year() - account.updated_at = account.created_at - # Create a tenant for the account from models.account import Tenant tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="normal", + status=TenantStatus.NORMAL, ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() @@ -83,7 +82,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake: Faker | None = None): """ Helper method to create a test dataset with realistic data. @@ -117,7 +116,9 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): + def _create_test_document( + self, db_session_with_containers: Session, dataset, account: Account, fake: Faker | None = None + ): """ Helper method to create a test document with realistic data. @@ -216,7 +217,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake: Faker | None = None): """ Helper method to create a dataset process rule. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index b6e7e6e5c9..77cd259833 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -13,6 +13,7 @@ from uuid import uuid4 import pytest from sqlalchemy import delete, func, select, update +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -162,7 +163,7 @@ class TestDocumentIndexingSyncTask: "indexing_runner": indexing_runner, } - def _create_notion_sync_context(self, db_session_with_containers, *, data_source_info: dict | None = None): + def _create_notion_sync_context(self, db_session_with_containers: Session, *, data_source_info: dict | None = None): account, tenant = DocumentIndexingSyncTaskTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DocumentIndexingSyncTaskTestDataFactory.create_dataset( db_session_with_containers, @@ -206,7 +207,7 @@ class TestDocumentIndexingSyncTask: "notion_info": notion_info, } - def test_document_not_found(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task handles missing document gracefully.""" # Arrange dataset_id = str(uuid4()) @@ -219,7 +220,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["datasource_service"].get_datasource_credentials.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_missing_notion_workspace_id(self, db_session_with_containers, mock_external_dependencies): + def test_missing_notion_workspace_id(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when notion_workspace_id is missing.""" # Arrange context = self._create_notion_sync_context( @@ -235,7 +236,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_missing_notion_page_id(self, db_session_with_containers, mock_external_dependencies): + def test_missing_notion_page_id(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when notion_page_id is missing.""" # Arrange context = self._create_notion_sync_context( @@ -251,7 +252,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_empty_data_source_info(self, db_session_with_containers, mock_external_dependencies): + def test_empty_data_source_info(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task raises error when data_source_info is empty.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers, data_source_info=None) @@ -264,7 +265,7 @@ class TestDocumentIndexingSyncTask: with pytest.raises(ValueError, match="no notion page found"): document_indexing_sync_task(context["dataset"].id, context["document"].id) - def test_credential_not_found(self, db_session_with_containers, mock_external_dependencies): + def test_credential_not_found(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task sets document error state when credential is missing.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -284,7 +285,7 @@ class TestDocumentIndexingSyncTask: assert updated_document.stopped_at is not None mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_page_not_updated(self, db_session_with_containers, mock_external_dependencies): + def test_page_not_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task exits early when notion page is unchanged.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -310,7 +311,7 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_not_called() - def test_successful_sync_when_page_updated(self, db_session_with_containers, mock_external_dependencies): + def test_successful_sync_when_page_updated(self, db_session_with_containers: Session, mock_external_dependencies): """Test full successful sync flow with SQL state updates and side effects.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -349,7 +350,7 @@ class TestDocumentIndexingSyncTask: assert len(run_documents) == 1 assert getattr(run_documents[0], "id", None) == context["document"].id - def test_dataset_not_found_during_cleaning(self, db_session_with_containers, mock_external_dependencies): + def test_dataset_not_found_during_cleaning(self, db_session_with_containers: Session, mock_external_dependencies): """Test that task still updates document and reindexes if dataset vanishes before clean.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -376,7 +377,9 @@ class TestDocumentIndexingSyncTask: mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_cleaning_error_continues_to_indexing(self, db_session_with_containers, mock_external_dependencies): + def test_cleaning_error_continues_to_indexing( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that indexing continues when index cleanup fails.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -400,7 +403,9 @@ class TestDocumentIndexingSyncTask: assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() - def test_indexing_runner_document_paused_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_document_paused_error( + self, db_session_with_containers: Session, mock_external_dependencies + ): """Test that DocumentIsPausedError does not flip document into error state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) @@ -418,7 +423,7 @@ class TestDocumentIndexingSyncTask: assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None - def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): + def test_indexing_runner_general_error(self, db_session_with_containers: Session, mock_external_dependencies): """Test that indexing errors are persisted to document state.""" # Arrange context = self._create_notion_sync_context(db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index cf1a8666f3..6c1454b6d8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.document_task import DocumentTask from core.rag.index_processor.constant.index_type import IndexTechniqueType from enums.cloud_plan import CloudPlan -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( @@ -51,7 +52,7 @@ class TestDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -71,14 +72,14 @@ class TestDocumentIndexingTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -133,7 +134,7 @@ class TestDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. @@ -153,14 +154,14 @@ class TestDocumentIndexingTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -221,7 +222,9 @@ class TestDocumentIndexingTasks: return dataset, documents - def test_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_document_indexing_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with multiple documents. @@ -262,7 +265,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 3 def test_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -286,7 +289,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. @@ -332,7 +335,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 2 # Only existing documents def test_document_indexing_task_indexing_runner_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of IndexingRunner exceptions. @@ -373,7 +376,7 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test processing documents with mixed initial states. @@ -456,7 +459,7 @@ class TestDocumentIndexingTasks: assert len(processed_documents) == 4 def test_document_indexing_task_billing_sandbox_plan_batch_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for sandbox plan batch upload limit. @@ -518,7 +521,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner"].assert_not_called() def test_document_indexing_task_billing_disabled_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful processing when billing is disabled. @@ -554,7 +557,7 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of DocumentIsPausedError from IndexingRunner. @@ -597,7 +600,9 @@ class TestDocumentIndexingTasks: assert updated_document.processing_started_at is not None # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== - def test_old_document_indexing_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_old_document_indexing_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test document_indexing_task basic functionality. @@ -619,7 +624,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_normal_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_document_indexing_task basic functionality. @@ -643,7 +648,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_priority_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_document_indexing_task basic functionality. @@ -667,7 +672,7 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_document_indexing_with_tenant_queue_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test _document_indexing_with_tenant_queue function with no waiting tasks. @@ -717,7 +722,7 @@ class TestDocumentIndexingTasks: mock_task_func.delay.assert_not_called() def test_document_indexing_with_tenant_queue_with_waiting_tasks( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test _document_indexing_with_tenant_queue function with waiting tasks in queue using real Redis. @@ -776,7 +781,7 @@ class TestDocumentIndexingTasks: assert len(remaining_tasks) == 1 def test_document_indexing_with_tenant_queue_error_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling in _document_indexing_with_tenant_queue using real Redis. @@ -848,7 +853,7 @@ class TestDocumentIndexingTasks: assert len(remaining_tasks) == 0 def test_document_indexing_with_tenant_queue_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation in _document_indexing_with_tenant_queue using real Redis. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index a9a8c0f30c..208fc1aa1d 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -3,9 +3,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import func, select +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Dataset, Document, DocumentSegment from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_update_task import document_indexing_update_task @@ -33,7 +34,7 @@ class TestDocumentIndexingUpdateTask: "runner_instance": runner_instance, } - def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2): + def _create_dataset_document_with_segments(self, db_session_with_containers: Session, *, segment_count: int = 2): fake = Faker() # Account and tenant @@ -41,12 +42,12 @@ class TestDocumentIndexingUpdateTask: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() - tenant = Tenant(name=fake.company(), status="normal") + tenant = Tenant(name=fake.company(), status=TenantStatus.NORMAL) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -114,7 +115,7 @@ class TestDocumentIndexingUpdateTask: return dataset, document, node_ids - def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies): + def test_cleans_segments_and_reindexes(self, db_session_with_containers: Session, mock_external_dependencies): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Act @@ -153,7 +154,9 @@ class TestDocumentIndexingUpdateTask: first = run_docs[0] assert getattr(first, "id", None) == document.id - def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies): + def test_clean_error_is_logged_and_indexing_continues( + self, db_session_with_containers: Session, mock_external_dependencies + ): dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) # Force clean to raise; task should continue to indexing @@ -173,7 +176,7 @@ class TestDocumentIndexingUpdateTask: ) assert remaining > 0 - def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies): + def test_document_not_found_noop(self, db_session_with_containers: Session, mock_external_dependencies): fake = Faker() # Act with non-existent document id document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index 39c58987fd..12440f3e6b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.indexing_runner import DocumentIsPausedError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType @@ -62,7 +63,7 @@ class TestDuplicateDocumentIndexingTasks: } def _create_test_dataset_and_documents( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, document_count=3 ): """ Helper method to create a test dataset and documents for testing. @@ -145,7 +146,11 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _create_test_dataset_with_segments( - self, db_session_with_containers, mock_external_service_dependencies, document_count=3, segments_per_doc=2 + self, + db_session_with_containers: Session, + mock_external_service_dependencies, + document_count=3, + segments_per_doc=2, ): """ Helper method to create a test dataset with documents and segments. @@ -197,7 +202,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents, segments def _create_test_dataset_with_billing_features( - self, db_session_with_containers, mock_external_service_dependencies, billing_enabled=True + self, db_session_with_containers: Session, mock_external_service_dependencies, billing_enabled=True ): """ Helper method to create a test dataset with billing features configured. @@ -287,7 +292,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents def _test_duplicate_document_indexing_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful duplicate document indexing with multiple documents. @@ -329,7 +334,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 3 def _test_duplicate_document_indexing_task_with_segment_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test duplicate document indexing with existing segments that need cleanup. @@ -379,7 +384,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def _test_duplicate_document_indexing_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -404,7 +409,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["index_processor"].clean.assert_not_called() def test_duplicate_document_indexing_task_document_not_found_in_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when some documents don't exist in the dataset. @@ -450,7 +455,7 @@ class TestDuplicateDocumentIndexingTasks: assert len(processed_documents) == 2 # Only existing documents def _test_duplicate_document_indexing_task_indexing_runner_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of IndexingRunner exceptions. @@ -491,7 +496,7 @@ class TestDuplicateDocumentIndexingTasks: assert updated_document.processing_started_at is not None def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for sandbox plan batch upload limit. @@ -554,7 +559,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test billing validation for vector space limit. @@ -596,7 +601,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() def test_duplicate_document_indexing_task_with_empty_document_list( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of empty document list. @@ -622,7 +627,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once_with([]) def test_deprecated_duplicate_document_indexing_task_delegates_to_core( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that deprecated duplicate_document_indexing_task delegates to core function. @@ -655,7 +660,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test normal_duplicate_document_indexing_task with tenant isolation queue. @@ -698,7 +703,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test priority_duplicate_document_indexing_task with tenant isolation queue. @@ -742,7 +747,7 @@ class TestDuplicateDocumentIndexingTasks: @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( - self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies + self, mock_queue_class, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant queue wrapper processes next queued tasks. @@ -789,7 +794,7 @@ class TestDuplicateDocumentIndexingTasks: mock_queue.delete_task_key.assert_not_called() def test_successful_duplicate_document_indexing( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test successful duplicate document indexing flow.""" self._test_duplicate_document_indexing_task_success( @@ -797,7 +802,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when dataset is not found.""" self._test_duplicate_document_indexing_task_dataset_not_found( @@ -805,7 +810,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing with billing enabled and sandbox plan.""" self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( @@ -813,7 +818,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_with_billing_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when billing limit is exceeded.""" self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( @@ -821,7 +826,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_runner_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when IndexingRunner raises an error.""" self._test_duplicate_document_indexing_task_indexing_runner_exception( @@ -829,7 +834,7 @@ class TestDuplicateDocumentIndexingTasks: ) def _test_duplicate_document_indexing_task_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" # Arrange @@ -860,7 +865,7 @@ class TestDuplicateDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() def test_duplicate_document_indexing_document_is_paused( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test duplicate document indexing when document is paused.""" self._test_duplicate_document_indexing_task_document_is_paused( @@ -868,7 +873,7 @@ class TestDuplicateDocumentIndexingTasks: ) def test_duplicate_document_indexing_cleans_old_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """Test that duplicate document indexing cleans old segments.""" self._test_duplicate_document_indexing_task_with_segment_cleanup( diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py index 177af266fb..a697878bb6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_change_mail_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -29,7 +30,7 @@ class TestMailChangeMailTask: "get_email_i18n_service": mock_get_email_i18n_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -72,7 +73,7 @@ class TestMailChangeMailTask: return account def test_send_change_mail_task_success_old_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for old_email phase. @@ -103,7 +104,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_success_new_email_phase( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email task execution for new_email phase. @@ -134,7 +135,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when mail service is not initialized. @@ -159,7 +160,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_change_email.assert_not_called() def test_send_change_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email task when email service raises an exception. @@ -191,7 +192,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful change email completed notification task execution. @@ -224,7 +225,7 @@ class TestMailChangeMailTask: ) def test_send_change_mail_completed_notification_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when mail service is not initialized. @@ -247,7 +248,7 @@ class TestMailChangeMailTask: mock_external_service_dependencies["email_i18n_service"].send_email.assert_not_called() def test_send_change_mail_completed_notification_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test change email completed notification task when email service raises an exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 8343711998..8e9da6aaaa 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -15,6 +15,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import delete +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -37,7 +38,7 @@ class TestSendEmailCodeLoginMailTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -71,7 +72,7 @@ class TestSendEmailCodeLoginMailTask: "email_service_instance": mock_email_service_instance, } - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test account for testing. @@ -98,7 +99,7 @@ class TestSendEmailCodeLoginMailTask: return account - def _create_test_tenant_and_account(self, db_session_with_containers, fake=None): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, fake: Faker | None = None): """ Helper method to create a test tenant and account for testing. @@ -138,7 +139,7 @@ class TestSendEmailCodeLoginMailTask: return account, tenant def test_send_email_code_login_mail_task_success_english( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in English. @@ -182,7 +183,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_chinese( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending in Chinese. @@ -221,7 +222,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_success_multiple_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful email code login mail sending with multiple languages. @@ -261,7 +262,7 @@ class TestSendEmailCodeLoginMailTask: assert call_args[1]["template_context"]["code"] == test_codes[i] def test_send_email_code_login_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when mail service is not initialized. @@ -299,7 +300,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_not_called() def test_send_email_code_login_mail_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task when email service raises an exception. @@ -346,7 +347,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_invalid_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with invalid parameters. @@ -388,7 +389,7 @@ class TestSendEmailCodeLoginMailTask: mock_email_service_instance.send_email.assert_called_once() def test_send_email_code_login_mail_task_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with edge cases and boundary conditions. @@ -451,7 +452,7 @@ class TestSendEmailCodeLoginMailTask: ) def test_send_email_code_login_mail_task_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with database integration. @@ -497,7 +498,7 @@ class TestSendEmailCodeLoginMailTask: assert account.status == "active" def test_send_email_code_login_mail_task_redis_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email code login mail task with Redis integration. @@ -541,7 +542,7 @@ class TestSendEmailCodeLoginMailTask: redis_client.delete(cache_key) def test_send_email_code_login_mail_task_error_handling_comprehensive( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error handling for email code login mail task. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 95a867dbb5..f505361727 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from sqlalchemy import delete +from sqlalchemy.orm import Session from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig @@ -172,7 +173,9 @@ def _create_workflow_pause_state( db_session_with_containers.commit() -def test_dispatch_human_input_email_task_integration(monkeypatch: pytest.MonkeyPatch, db_session_with_containers): +def test_dispatch_human_input_email_task_integration( + monkeypatch: pytest.MonkeyPatch, db_session_with_containers: Session +): tenant, account = _create_workspace_member(db_session_with_containers) workflow_run_id = str(uuid.uuid4()) workflow_id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py index 1a20b6deec..f8e54ea9e6 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_inner_task.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from tasks.mail_inner_task import send_inner_email_task @@ -51,7 +52,7 @@ class TestMailInnerTask: }, } - def test_send_inner_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful email sending with valid data. @@ -90,7 +91,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_single_recipient(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_single_recipient( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with single recipient. @@ -126,7 +129,9 @@ class TestMailInnerTask: html_content="Test email content", ) - def test_send_inner_email_empty_substitutions(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_empty_substitutions( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending with empty substitutions. @@ -163,7 +168,7 @@ class TestMailInnerTask: ) def test_send_inner_email_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when mail service is not initialized. @@ -193,7 +198,7 @@ class TestMailInnerTask: mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() def test_send_inner_email_template_rendering_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending when template rendering fails. @@ -222,7 +227,9 @@ class TestMailInnerTask: # Verify no email service calls due to exception mock_external_service_dependencies["email_service"].send_raw_email.assert_not_called() - def test_send_inner_email_service_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_inner_email_service_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email sending when email service fails. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py index d34828c4b1..c8c7a4d961 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_invite_member_task.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import delete, select +from sqlalchemy.orm import Session from extensions.ext_redis import redis_client from libs.email_i18n import EmailType @@ -42,7 +43,7 @@ class TestMailInviteMemberTask: """ @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" # Clear all test data db_session_with_containers.execute(delete(TenantAccountJoin)) @@ -78,7 +79,7 @@ class TestMailInviteMemberTask: "config": mock_config, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -147,7 +148,7 @@ class TestMailInviteMemberTask: redis_client.setex(cache_key, 24 * 60 * 60, json.dumps(invitation_data)) # 24 hours return token - def _create_pending_account_for_invitation(self, db_session_with_containers, email, tenant): + def _create_pending_account_for_invitation(self, db_session_with_containers: Session, email, tenant): """ Helper method to create a pending account for invitation testing. @@ -185,7 +186,9 @@ class TestMailInviteMemberTask: return account - def test_send_invite_member_mail_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_invite_member_mail_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful invitation email sending with all parameters. @@ -231,7 +234,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_different_languages( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test invitation email sending with different language codes. @@ -263,7 +266,7 @@ class TestMailInviteMemberTask: assert call_args[1]["language_code"] == language def test_send_invite_member_mail_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test behavior when mail service is not initialized. @@ -292,7 +295,7 @@ class TestMailInviteMemberTask: mock_email_service.send_email.assert_not_called() def test_send_invite_member_mail_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when email service raises an exception. @@ -322,7 +325,7 @@ class TestMailInviteMemberTask: assert "Send invite member mail to %s failed" in error_call def test_send_invite_member_mail_template_context_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test template context contains all required fields for email rendering. @@ -368,7 +371,7 @@ class TestMailInviteMemberTask: assert template_context["url"] == f"https://console.dify.ai/activate?token={token}" def test_send_invite_member_mail_integration_with_redis_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test integration with Redis token validation. @@ -407,7 +410,7 @@ class TestMailInviteMemberTask: assert invitation_data["workspace_id"] == tenant.id def test_send_invite_member_mail_with_special_characters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test email sending with special characters in names and workspace names. @@ -449,7 +452,7 @@ class TestMailInviteMemberTask: assert template_context["workspace_name"] == workspace_name def test_send_invite_member_mail_real_database_integration( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test real database integration with actual invitation flow. @@ -501,7 +504,7 @@ class TestMailInviteMemberTask: assert tenant_join.role == TenantAccountRole.NORMAL def test_send_invite_member_mail_token_lifecycle_management( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test token lifecycle management and validation. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py index e08b099480..176645a4ab 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_owner_transfer_task.py @@ -11,6 +11,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -44,7 +45,7 @@ class TestMailOwnerTransferTask: "get_email_service": mock_get_email_service, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create test account and tenant for testing. @@ -86,7 +87,9 @@ class TestMailOwnerTransferTask: return account, tenant - def test_send_owner_transfer_confirm_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_owner_transfer_confirm_task_success( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """ Test successful owner transfer confirmation email sending. @@ -127,7 +130,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_owner_transfer_confirm_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test owner transfer confirmation email when mail service is not initialized. @@ -158,7 +161,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_owner_transfer_confirm_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in owner transfer confirmation email. @@ -192,7 +195,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_old_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful old owner transfer notification email sending. @@ -234,7 +237,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["NewOwnerEmail"] == test_new_owner_email def test_send_old_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test old owner transfer notification email when mail service is not initialized. @@ -265,7 +268,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_old_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in old owner transfer notification email. @@ -299,7 +302,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_called_once() def test_send_new_owner_transfer_notify_email_task_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test successful new owner transfer notification email sending. @@ -338,7 +341,7 @@ class TestMailOwnerTransferTask: assert call_args[1]["template_context"]["WorkspaceName"] == test_workspace def test_send_new_owner_transfer_notify_email_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test new owner transfer notification email when mail service is not initialized. @@ -367,7 +370,7 @@ class TestMailOwnerTransferTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_new_owner_transfer_notify_email_task_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """ Test exception handling in new owner transfer notification email. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py index cced6f7780..071971f324 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -9,6 +9,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from libs.email_i18n import EmailType from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist @@ -35,7 +36,7 @@ class TestMailRegisterTask: "get_email_service": mock_get_email_service, } - def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_success(self, db_session_with_containers: Session, mock_mail_dependencies): """Test successful email registration mail sending.""" fake = Faker() language = "en-US" @@ -56,7 +57,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test email registration task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -66,7 +67,9 @@ class TestMailRegisterTask: mock_mail_dependencies["get_email_service"].assert_not_called() mock_mail_dependencies["email_service"].send_email.assert_not_called() - def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies): + def test_send_email_register_mail_task_exception_handling( + self, db_session_with_containers: Session, mock_mail_dependencies + ): """Test email registration task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") @@ -79,7 +82,7 @@ class TestMailRegisterTask: mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) def test_send_email_register_mail_task_when_account_exist_success( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test successful email registration mail sending when account exists.""" fake = Faker() @@ -105,7 +108,7 @@ class TestMailRegisterTask: ) def test_send_email_register_mail_task_when_account_exist_mail_not_initialized( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task when mail service is not initialized.""" mock_mail_dependencies["mail"].is_inited.return_value = False @@ -118,7 +121,7 @@ class TestMailRegisterTask: mock_mail_dependencies["email_service"].send_email.assert_not_called() def test_send_email_register_mail_task_when_account_exist_exception_handling( - self, db_session_with_containers, mock_mail_dependencies + self, db_session_with_containers: Session, mock_mail_dependencies ): """Test account exist email task exception handling.""" mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index f01fcc1742..5eea985fdc 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -4,12 +4,13 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from flask import Flask from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus from models.dataset import Pipeline from models.workflow import Workflow from tasks.rag_pipeline.priority_rag_pipeline_run_task import ( @@ -69,14 +70,14 @@ class TestRagPipelineRunTasks: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -725,7 +726,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test successful run_single_rag_pipeline_task execution. @@ -760,7 +761,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -805,7 +806,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers: Flask ): """ Test run_single_rag_pipeline_task with non-existent database entities. diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index b43b622870..03c02ea341 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -3,6 +3,7 @@ from unittest.mock import ANY, call, patch import pytest from sqlalchemy import delete, func, select +from sqlalchemy.orm import Session from core.db.session_factory import session_factory from extensions.storage.storage_type import StorageType @@ -117,7 +118,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: class TestDeleteDraftVariablesBatch: - def test_delete_draft_variables_batch_success(self, db_session_with_containers): + def test_delete_draft_variables_batch_success(self, db_session_with_containers: Session): """Test successful deletion of draft variables in batches.""" _, app1 = _create_tenant_and_app(db_session_with_containers) _, app2 = _create_tenant_and_app(db_session_with_containers) @@ -137,7 +138,7 @@ class TestDeleteDraftVariablesBatch: assert app1_remaining_count == 0 assert app2_remaining_count == 100 - def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers): + def test_delete_draft_variables_batch_empty_result(self, db_session_with_containers: Session): """Test deletion when no draft variables exist for the app.""" result = delete_draft_variables_batch(str(uuid.uuid4()), 1000) @@ -176,7 +177,7 @@ class TestDeleteDraftVariableOffloadData: """Test the Offload data cleanup functionality.""" @patch("extensions.ext_storage.storage") - def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers): + def test_delete_draft_variable_offload_data_success(self, mock_storage, db_session_with_containers: Session): """Test successful deletion of offload data.""" tenant, app = _create_tenant_and_app(db_session_with_containers) offload_data = _create_offload_data(db_session_with_containers, tenant_id=tenant.id, app_id=app.id, count=3) diff --git a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py index 34a1941c39..6365207661 100644 --- a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py +++ b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py @@ -1,12 +1,14 @@ from pathlib import Path +import pytest + from extensions.storage.opendal_storage import OpenDALStorage class TestOpenDALFsDefaultRoot: """Test that OpenDALStorage with scheme='fs' works correctly when no root is provided.""" - def test_fs_without_root_uses_default(self, tmp_path, monkeypatch): + def test_fs_without_root_uses_default(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When no root is specified, the default 'storage' should be used and passed to the Operator.""" # Change to tmp_path so the default "storage" dir is created there monkeypatch.chdir(tmp_path) @@ -25,7 +27,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_default_root.txt") - def test_fs_with_explicit_root(self, tmp_path): + def test_fs_with_explicit_root(self, tmp_path: Path): """When root is explicitly provided, it should be used.""" custom_root = str(tmp_path / "custom_storage") storage = OpenDALStorage(scheme="fs", root=custom_root) @@ -38,7 +40,7 @@ class TestOpenDALFsDefaultRoot: # Cleanup storage.delete("test_explicit_root.txt") - def test_fs_with_env_var_root(self, tmp_path, monkeypatch): + def test_fs_with_env_var_root(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch): """When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs.""" env_root = str(tmp_path / "env_storage") monkeypatch.setenv("OPENDAL_FS_ROOT", env_root) diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index b00d827e37..6402e7da2b 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -175,7 +175,7 @@ class TestWorkflowPauseIntegration: """Comprehensive integration tests for workflow pause functionality.""" @pytest.fixture(autouse=True) - def setup_test_data(self, db_session_with_containers): + def setup_test_data(self, db_session_with_containers: Session): """Set up test data for each test method using TestContainers.""" # Create test tenant and account diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py index 19a41b6186..a5086b4c5d 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_javascript.py @@ -1,12 +1,14 @@ from textwrap import dedent +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): """Test class for JavaScript code executor functionality.""" - def test_javascript_plain(self, flask_app_with_containers): + def test_javascript_plain(self, flask_app_with_containers: Flask): """Test basic JavaScript code execution with console.log output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -14,7 +16,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): result_message = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) assert result_message == "Hello World\n" - def test_javascript_json(self, flask_app_with_containers): + def test_javascript_json(self, flask_app_with_containers: Flask): """Test JavaScript code execution with JSON output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -25,7 +27,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.JAVASCRIPT, preload="", code=code) assert result == '{"Hello":"World"}\n' - def test_javascript_with_code_template(self, flask_app_with_containers): + def test_javascript_with_code_template(self, flask_app_with_containers: Flask): """Test JavaScript workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports JavascriptCodeProvider, _ = self.javascript_imports @@ -37,7 +39,7 @@ class TestJavaScriptCodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "HelloWorld"} - def test_javascript_get_runner_script(self, flask_app_with_containers): + def test_javascript_get_runner_script(self, flask_app_with_containers: Flask): """Test JavaScript template transformer runner script generation""" _, NodeJsTemplateTransformer = self.javascript_imports diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py index ddb079f00c..8b4c3c3d4a 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_jinja2.py @@ -1,12 +1,14 @@ import base64 +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestJinja2CodeExecutor(CodeExecutorTestMixin): """Test class for Jinja2 code executor functionality.""" - def test_jinja2(self, flask_app_with_containers): + def test_jinja2(self, flask_app_with_containers: Flask): """Test basic Jinja2 template execution with variable substitution""" CodeExecutor, CodeLanguage = self.code_executor_imports _, Jinja2TemplateTransformer = self.jinja2_imports @@ -25,7 +27,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): ) assert result == "<>Hello World<>\n" - def test_jinja2_with_code_template(self, flask_app_with_containers): + def test_jinja2_with_code_template(self, flask_app_with_containers: Flask): """Test Jinja2 workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -34,7 +36,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "Hello World"} - def test_jinja2_get_runner_script(self, flask_app_with_containers): + def test_jinja2_get_runner_script(self, flask_app_with_containers: Flask): """Test Jinja2 template transformer runner script generation""" _, Jinja2TemplateTransformer = self.jinja2_imports @@ -43,7 +45,7 @@ class TestJinja2CodeExecutor(CodeExecutorTestMixin): assert runner_script.count(Jinja2TemplateTransformer._inputs_placeholder) == 1 assert runner_script.count(Jinja2TemplateTransformer._result_tag) == 2 - def test_jinja2_template_with_special_characters(self, flask_app_with_containers): + def test_jinja2_template_with_special_characters(self, flask_app_with_containers: Flask): """ Test that templates with special characters (quotes, newlines) render correctly. This is a regression test for issue #26818 where textarea pre-fill values diff --git a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py index 6d93df2472..0de41e1312 100644 --- a/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ b/api/tests/test_containers_integration_tests/workflow/nodes/code_executor/test_code_python3.py @@ -1,12 +1,14 @@ from textwrap import dedent +from flask import Flask + from .test_utils import CodeExecutorTestMixin class TestPython3CodeExecutor(CodeExecutorTestMixin): """Test class for Python3 code executor functionality.""" - def test_python3_plain(self, flask_app_with_containers): + def test_python3_plain(self, flask_app_with_containers: Flask): """Test basic Python3 code execution with print output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -14,7 +16,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) assert result == "Hello World\n" - def test_python3_json(self, flask_app_with_containers): + def test_python3_json(self, flask_app_with_containers: Flask): """Test Python3 code execution with JSON output""" CodeExecutor, CodeLanguage = self.code_executor_imports @@ -25,7 +27,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): result = CodeExecutor.execute_code(language=CodeLanguage.PYTHON3, preload="", code=code) assert result == '{"Hello": "World"}\n' - def test_python3_with_code_template(self, flask_app_with_containers): + def test_python3_with_code_template(self, flask_app_with_containers: Flask): """Test Python3 workflow code template execution with inputs""" CodeExecutor, CodeLanguage = self.code_executor_imports Python3CodeProvider, _ = self.python3_imports @@ -37,7 +39,7 @@ class TestPython3CodeExecutor(CodeExecutorTestMixin): ) assert result == {"result": "HelloWorld"} - def test_python3_get_runner_script(self, flask_app_with_containers): + def test_python3_get_runner_script(self, flask_app_with_containers: Flask): """Test Python3 template transformer runner script generation""" _, Python3TemplateTransformer = self.python3_imports diff --git a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py index d3e864a75a..78413a0798 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_account_activation.py +++ b/api/tests/unit_tests/controllers/console/auth/test_account_activation.py @@ -67,7 +67,7 @@ class TestActivateCheckApi: assert response["data"]["email"] == "invitee@example.com" @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_check_invalid_invitation_token(self, mock_get_invitation, app): + def test_check_invalid_invitation_token(self, mock_get_invitation, app: Flask): """ Test checking invalid invitation token. @@ -227,7 +227,7 @@ class TestActivateApi: mock_db.session.commit.assert_called_once() @patch("controllers.console.auth.activate.RegisterService.get_invitation_with_case_fallback") - def test_activation_with_invalid_token(self, mock_get_invitation, app): + def test_activation_with_invalid_token(self, mock_get_invitation, app: Flask): """ Test account activation with invalid token. diff --git a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py index b7bc73da5f..7b2c7569fe 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_email_verification.py +++ b/api/tests/unit_tests/controllers/console/auth/test_email_verification.py @@ -140,7 +140,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") - def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app): + def test_send_email_code_ip_rate_limited(self, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending blocked by IP rate limit. @@ -160,7 +160,7 @@ class TestEmailCodeLoginSendEmailApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.is_email_send_ip_limit") @patch("controllers.console.auth.login.AccountService.get_user_through_email") - def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app): + def test_send_email_code_frozen_account(self, mock_get_user, mock_is_ip_limit, mock_db, app: Flask): """ Test email code sending to frozen account. @@ -353,7 +353,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app): + def test_email_code_login_invalid_token(self, mock_get_data, mock_db, app: Flask): """ Test email code login with invalid token. @@ -375,7 +375,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app): + def test_email_code_login_email_mismatch(self, mock_get_data, mock_db, app: Flask): """ Test email code login with mismatched email. @@ -397,7 +397,7 @@ class TestEmailCodeLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.AccountService.get_email_code_login_data") - def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app): + def test_email_code_login_wrong_code(self, mock_get_data, mock_db, app: Flask): """ Test email code login with incorrect code. diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index d089be8905..5284f29eed 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -9,7 +9,7 @@ This module tests the core authentication endpoints including: """ import base64 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from flask import Flask @@ -52,12 +52,12 @@ class TestLoginApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(LoginApi, "/login") return app.test_client() @@ -97,7 +97,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -141,14 +141,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.AccountService.reset_login_error_rate_limit") def test_successful_login_with_valid_invitation( self, - mock_reset_rate_limit, + mock_reset_rate_limit: Mock, mock_login, mock_get_tenants, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, mock_account, mock_token_pair, ): @@ -188,7 +188,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_fails_when_rate_limited(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login rejection when rate limit is exceeded. @@ -216,7 +216,7 @@ class TestLoginApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", True) @patch("controllers.console.auth.login.BillingService.is_email_in_freeze") - def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app): + def test_login_fails_when_account_frozen(self, mock_is_frozen, mock_db, app: Flask): """ Test login rejection for frozen accounts. @@ -253,7 +253,7 @@ class TestLoginApi: mock_get_invitation, mock_is_rate_limit, mock_db, - app, + app: Flask, ): """ Test login failure with invalid credentials. @@ -290,7 +290,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") @patch("controllers.console.auth.login.AccountService.authenticate") def test_login_fails_for_banned_account( - self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app + self, mock_authenticate, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask ): """ Test login rejection for banned accounts. @@ -328,14 +328,14 @@ class TestLoginApi: @patch("controllers.console.auth.login.FeatureService.get_system_features") def test_login_fails_when_no_workspace_and_limit_exceeded( self, - mock_get_features, - mock_get_tenants, - mock_authenticate, - mock_get_invitation, - mock_is_rate_limit, - mock_db, - app, - mock_account, + mock_get_features: MagicMock, + mock_get_tenants: MagicMock, + mock_authenticate: MagicMock, + mock_get_invitation: MagicMock, + mock_is_rate_limit: MagicMock, + mock_db: MagicMock, + app: Flask, + mock_account: MagicMock, ): """ Test login failure when user has no workspace and workspace limit exceeded. @@ -367,7 +367,7 @@ class TestLoginApi: @patch("controllers.console.auth.login.dify_config.BILLING_ENABLED", False) @patch("controllers.console.auth.login.AccountService.is_login_error_rate_limit") @patch("controllers.console.auth.login.RegisterService.get_invitation_with_case_fallback") - def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app): + def test_login_invitation_email_mismatch(self, mock_get_invitation, mock_is_rate_limit, mock_db, app: Flask): """ Test login failure when invitation email doesn't match login email. @@ -491,7 +491,7 @@ class TestLogoutApi: @patch("controllers.console.auth.login.AccountService.logout") @patch("controllers.console.auth.login.flask_login.logout_user") def test_successful_logout( - self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app, mock_account + self, mock_logout_user, mock_service_logout, mock_current_account, mock_db, app: Flask, mock_account ): """ Test successful logout flow. @@ -518,7 +518,7 @@ class TestLogoutApi: @patch("controllers.console.wraps.db") @patch("controllers.console.auth.login.current_account_with_tenant") @patch("controllers.console.auth.login.flask_login") - def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app): + def test_logout_anonymous_user(self, mock_flask_login, mock_current_account, mock_db, app: Flask): """ Test logout for anonymous (not logged in) user. diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index d010f60866..15c95f6b94 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -28,12 +28,12 @@ class TestRefreshTokenApi: return app @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return Api(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client.""" api.add_resource(RefreshTokenApi, "/refresh-token") return app.test_client() diff --git a/api/tests/unit_tests/controllers/console/billing/test_billing.py b/api/tests/unit_tests/controllers/console/billing/test_billing.py index 810f1b94fc..defa9064fd 100644 --- a/api/tests/unit_tests/controllers/console/billing/test_billing.py +++ b/api/tests/unit_tests/controllers/console/billing/test_billing.py @@ -49,7 +49,7 @@ class TestPartnerTenants: mock_csrf.return_value = None yield {"db": mock_db, "csrf": mock_csrf} - def test_put_success(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_success(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test successful partner tenants bindings sync.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -79,7 +79,7 @@ class TestPartnerTenants: mock_account.id, "partner-key-123", click_id ) - def test_put_invalid_partner_key_base64(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_invalid_partner_key_base64(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that invalid base64 partner_key raises BadRequest.""" # Arrange invalid_partner_key = "invalid-base64-!@#$" @@ -104,7 +104,7 @@ class TestPartnerTenants: resource.put(invalid_partner_key) assert "Invalid partner_key" in str(exc_info.value) - def test_put_missing_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_missing_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that missing click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -128,7 +128,9 @@ class TestPartnerTenants: with pytest.raises(BadRequest): resource.put(partner_key_encoded) - def test_put_billing_service_json_decode_error(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_billing_service_json_decode_error( + self, app: Flask, mock_account, mock_billing_service, mock_decorators + ): """Test handling of billing service JSON decode error. When billing service returns non-200 status code with invalid JSON response, @@ -174,7 +176,7 @@ class TestPartnerTenants: assert isinstance(exc_info.value, json.JSONDecodeError) assert "Expecting value" in str(exc_info.value) - def test_put_empty_click_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_click_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty click_id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") @@ -199,7 +201,7 @@ class TestPartnerTenants: resource.put(partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_partner_key_after_decode(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_partner_key_after_decode(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty partner_key after decode raises BadRequest.""" # Arrange # Base64 encode an empty string @@ -225,7 +227,7 @@ class TestPartnerTenants: resource.put(empty_partner_key_encoded) assert "Invalid partner information" in str(exc_info.value) - def test_put_empty_user_id(self, app, mock_account, mock_billing_service, mock_decorators): + def test_put_empty_user_id(self, app: Flask, mock_account, mock_billing_service, mock_decorators): """Test that empty user id raises BadRequest.""" # Arrange partner_key_encoded = base64.b64encode(b"partner-key-123").decode("utf-8") diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index 6405558bb4..a26d171649 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -8,10 +8,8 @@ from werkzeug.exceptions import Forbidden import controllers.console.tag.tags as module from controllers.console import console_ns from controllers.console.tag.tags import ( - DeprecatedTagBindingCreateApi, - DeprecatedTagBindingRemoveApi, TagBindingCollectionApi, - TagBindingItemApi, + TagBindingRemoveApi, TagListApi, TagUpdateDeleteApi, ) @@ -249,39 +247,13 @@ class TestTagBindingCollectionApi: method(api) -class TestDeprecatedTagBindingCreateApi: - def test_create_success(self, app, admin_user, payload_patch): - api = DeprecatedTagBindingCreateApi() +class TestTagBindingRemoveApi: + def test_remove_success(self, app, admin_user, payload_patch): + api = TagBindingRemoveApi() method = unwrap(api.post) payload = { - "tag_ids": ["tag-1"], - "target_id": "target-1", - "type": "knowledge", - } - - with app.test_request_context("/", json=payload): - with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), - payload_patch(payload), - patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, - ): - result, status = method(api) - - save_mock.assert_called_once() - assert status == 200 - assert result["result"] == "success" - - -class TestTagBindingItemApi: - def test_delete_success(self, app, admin_user, payload_patch): - api = TagBindingItemApi() - method = unwrap(api.delete) - - payload = { + "tag_ids": ["tag-1", "tag-2"], "target_id": "target-1", "type": "knowledge", } @@ -295,57 +267,16 @@ class TestTagBindingItemApi: payload_patch(payload), patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, ): - result, status = method(api, "tag-1") + result, status = method(api) delete_mock.assert_called_once() delete_payload = delete_mock.call_args.args[0] - assert delete_payload.tag_id == "tag-1" - assert delete_payload.target_id == "target-1" - assert delete_payload.type == TagType.KNOWLEDGE - assert status == 200 - assert result["result"] == "success" - - def test_delete_forbidden(self, app, readonly_user): - api = TagBindingItemApi() - method = unwrap(api.delete) - - with app.test_request_context("/"): - with patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ): - with pytest.raises(Forbidden): - method(api, "tag-1") - - -class TestDeprecatedTagBindingRemoveApi: - def test_remove_success(self, app, admin_user, payload_patch): - api = DeprecatedTagBindingRemoveApi() - method = unwrap(api.post) - - payload = { - "tag_id": "tag-1", - "target_id": "target-1", - "type": "knowledge", - } - - with app.test_request_context("/", json=payload): - with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), - payload_patch(payload), - patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, - ): - result, status = method(api) - - delete_mock.assert_called_once() + assert delete_payload.tag_ids == ["tag-1", "tag-2"] assert status == 200 assert result["result"] == "success" def test_remove_forbidden(self, app, readonly_user, payload_patch): - api = DeprecatedTagBindingRemoveApi() + api = TagBindingRemoveApi() method = unwrap(api.post) with app.test_request_context("/", json={}): @@ -371,32 +302,30 @@ class TestTagResponseModel: class TestTagBindingRouteMetadata: - def test_legacy_write_routes_are_marked_deprecated(self): - assert DeprecatedTagBindingCreateApi.post.__apidoc__["deprecated"] is True - assert DeprecatedTagBindingRemoveApi.post.__apidoc__["deprecated"] is True + def test_write_routes_are_not_deprecated(self): assert TagBindingCollectionApi.post.__apidoc__.get("deprecated") is not True - assert TagBindingItemApi.delete.__apidoc__.get("deprecated") is not True + assert TagBindingRemoveApi.post.__apidoc__.get("deprecated") is not True def test_write_routes_have_stable_operation_ids(self): assert TagBindingCollectionApi.post.__apidoc__["id"] == "create_tag_binding" - assert TagBindingItemApi.delete.__apidoc__["id"] == "delete_tag_binding" - assert DeprecatedTagBindingCreateApi.post.__apidoc__["id"] == "create_tag_binding_deprecated" - assert DeprecatedTagBindingRemoveApi.post.__apidoc__["id"] == "delete_tag_binding_deprecated" + assert TagBindingRemoveApi.post.__apidoc__["id"] == "remove_tag_bindings" - def test_canonical_and_legacy_write_routes_are_registered(self): + def test_write_routes_are_registered(self): route_map = { resource.__name__: urls for resource, urls, _route_doc, _kwargs in console_ns.resources if resource.__name__ in { "TagBindingCollectionApi", - "TagBindingItemApi", - "DeprecatedTagBindingCreateApi", - "DeprecatedTagBindingRemoveApi", + "TagBindingRemoveApi", } } assert route_map["TagBindingCollectionApi"] == ("/tag-bindings",) - assert route_map["TagBindingItemApi"] == ("/tag-bindings/",) - assert route_map["DeprecatedTagBindingCreateApi"] == ("/tag-bindings/create",) - assert route_map["DeprecatedTagBindingRemoveApi"] == ("/tag-bindings/remove",) + assert route_map["TagBindingRemoveApi"] == ("/tag-bindings/remove",) + + def test_legacy_write_routes_are_not_registered(self): + urls = {url for _resource, resource_urls, _route_doc, _kwargs in console_ns.resources for url in resource_urls} + + assert "/tag-bindings/create" not in urls + assert "/tag-bindings/" not in urls diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index 032b1377a4..99a90f3b67 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -17,7 +17,7 @@ def app(): return app -def test_parse_openapi_to_tool_bundle_operation_id(app): +def test_parse_openapi_to_tool_bundle_operation_id(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -63,7 +63,7 @@ def test_parse_openapi_to_tool_bundle_operation_id(app): assert tool_bundles[2].operation_id == "createResource" -def test_parse_openapi_to_tool_bundle_properties_all_of(app): +def test_parse_openapi_to_tool_bundle_properties_all_of(app: Flask): openapi = { "openapi": "3.0.0", "info": {"title": "Simple API", "version": "1.0.0"}, @@ -118,7 +118,7 @@ def test_parse_openapi_to_tool_bundle_properties_all_of(app): # assert set(tool_bundles[0].parameters[0].options) == {"option1", "option2", "option3"} -def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): +def test_parse_openapi_to_tool_bundle_default_value_type_casting(app: Flask): """ Test that default values are properly cast to match parameter types. This addresses the issue where array default values like [] cause validation errors diff --git a/api/tests/unit_tests/services/controller_api.py b/api/tests/unit_tests/services/controller_api.py index 762d7b9090..e7f7cabecd 100644 --- a/api/tests/unit_tests/services/controller_api.py +++ b/api/tests/unit_tests/services/controller_api.py @@ -146,7 +146,7 @@ class ControllerApiTestDataFactory: return app @staticmethod - def create_api_instance(app): + def create_api_instance(app: Flask): """ Create a Flask-RESTX API instance. @@ -160,7 +160,12 @@ class ControllerApiTestDataFactory: return api @staticmethod - def create_test_client(app, api, resource_class, route): + def create_test_client( + app: Flask, + api: Api, + resource_class: type, + route: str, + ): """ Create a Flask test client with a resource registered. @@ -302,7 +307,7 @@ class TestDatasetListApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """ Create Flask-RESTX API instance. @@ -311,7 +316,7 @@ class TestDatasetListApi: return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """ Create test client with DatasetListApi registered. @@ -472,12 +477,12 @@ class TestDatasetApiGet: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with DatasetApi registered.""" return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets/") @@ -588,12 +593,12 @@ class TestDatasetApiCreate: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with DatasetApi registered.""" return ControllerApiTestDataFactory.create_test_client(app, api, DatasetApi, "/datasets") @@ -681,12 +686,12 @@ class TestHitTestingApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client(self, app, api): + def client(self, app: Flask, api: Api): """Create test client with HitTestingApi registered.""" return ControllerApiTestDataFactory.create_test_client( app, api, HitTestingApi, "/datasets//hit-testing" @@ -799,12 +804,12 @@ class TestExternalDatasetApi: return ControllerApiTestDataFactory.create_flask_app() @pytest.fixture - def api(self, app): + def api(self, app: Flask): """Create Flask-RESTX API instance.""" return ControllerApiTestDataFactory.create_api_instance(app) @pytest.fixture - def client_list(self, app, api): + def client_list(self, app: Flask, api: Api): """Create test client for external knowledge API list endpoint.""" return ControllerApiTestDataFactory.create_test_client( app, api, ExternalApiTemplateListApi, "/datasets/external-knowledge-api" diff --git a/api/uv.lock b/api/uv.lock index 9806f506aa..6f75c9f6fe 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1629,7 +1629,7 @@ dev = [ { name = "lxml-stubs", specifier = ">=0.5.1" }, { name = "mypy", specifier = ">=1.20.2" }, { name = "pandas-stubs", specifier = ">=3.0.0" }, - { name = "pyrefly", specifier = ">=0.62.0" }, + { name = "pyrefly", specifier = ">=0.64.0" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "pytest-benchmark", specifier = ">=5.2.3" }, { name = "pytest-cov", specifier = ">=7.1.0" }, @@ -2657,14 +2657,14 @@ wheels = [ [[package]] name = "gitpython" -version = "3.1.47" +version = "3.1.49" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gitdb" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/bd/50db468e9b1310529a19fce651b3b0e753b5c07954d486cba31bbee9a5d5/gitpython-3.1.47.tar.gz", hash = "sha256:dba27f922bd2b42cb54c87a8ab3cb6beb6bf07f3d564e21ac848913a05a8a3cd", size = 216978, upload-time = "2026-04-22T02:44:44.059Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/63/210aaa302d6a0a78daa67c5c15bbac2cad361722841278b0209b6da20855/gitpython-3.1.49.tar.gz", hash = "sha256:42f9399c9eb33fc581014bedd76049dfbaf6375aa2a5754575966387280315e1", size = 219367, upload-time = "2026-04-29T00:31:20.478Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f2/c5/a1bc0996af85757903cf2bf444a7824e68e0035ce63fb41d6f76f9def68b/gitpython-3.1.47-py3-none-any.whl", hash = "sha256:489f590edfd6d20571b2c0e72c6a6ac6915ee8b8cd04572330e3842207a78905", size = 209547, upload-time = "2026-04-22T02:44:41.271Z" }, + { url = "https://files.pythonhosted.org/packages/fd/6f/b842bfa6f21d6f87c57f9abf7194225e55279d96d869775e19e9f7236fc5/gitpython-3.1.49-py3-none-any.whl", hash = "sha256:024b0422d7f84d15cd794844e029ffebd4c5d42a7eb9b936b458697ef550a02c", size = 212190, upload-time = "2026-04-29T00:31:18.412Z" }, ] [[package]] @@ -3740,14 +3740,14 @@ wheels = [ [[package]] name = "mako" -version = "1.3.11" +version = "1.3.12" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/59/8a/805404d0c0b9f3d7a326475ca008db57aea9c5c9f2e1e39ed0faa335571c/mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069", size = 399811, upload-time = "2026-04-14T20:19:51.493Z" } +sdist = { url = "https://files.pythonhosted.org/packages/00/62/791b31e69ae182791ec67f04850f2f062716bbd205483d63a215f3e062d3/mako-1.3.12.tar.gz", hash = "sha256:9f778e93289bd410bb35daadeb4fc66d95a746f0b75777b942088b7fd7af550a", size = 400219, upload-time = "2026-04-28T19:01:08.512Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/a5/19d7aaa7e433713ffe881df33705925a196afb9532efc8475d26593921a6/mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77", size = 78503, upload-time = "2026-04-14T20:19:53.233Z" }, + { url = "https://files.pythonhosted.org/packages/bc/b1/a0ec7a5a9db730a08daef1fdfb8090435b82465abbf758a596f0ea88727e/mako-1.3.12-py3-none-any.whl", hash = "sha256:8f61569480282dbf557145ce441e4ba888be453c30989f879f0d652e39f53ea9", size = 78521, upload-time = "2026-04-28T19:01:10.393Z" }, ] [[package]] @@ -5359,19 +5359,19 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.62.0" +version = "0.64.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bb/ad/8874ed25781e7dd561c6d75fb4a7becf10a18d75b074f25b845cc334f781/pyrefly-0.62.0.tar.gz", hash = "sha256:da1fbe1075dc1e6c8e3134e9370b0a0e7a296061d782cca5bf83dbb8e4c10d7c", size = 5537672, upload-time = "2026-04-20T17:12:15.718Z" } +sdist = { url = "https://files.pythonhosted.org/packages/85/99/923622d7b52ef84e83f357b19bd08dff063ccc5f4472b003105e1f308d93/pyrefly-0.64.0.tar.gz", hash = "sha256:fbfcdb0031adadc340b6c64cb41c6094c95349ee952fe3d4c143866add829172", size = 5678516, upload-time = "2026-05-06T17:28:44.056Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1b/ea/09bd9da7d5df294db800312fb415be2fefbaa5594178e9e49f44fa071aea/pyrefly-0.62.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9d78ec4f126dee1fa76215b193b964490ce10e62a32d2787a72c51623658b803", size = 13020414, upload-time = "2026-04-20T17:11:43.617Z" }, - { url = "https://files.pythonhosted.org/packages/4b/f0/f84afac4f220c4c8c801b779ee2ff28ad3f7731f4283c2e1b6ee9012e8c2/pyrefly-0.62.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2a41a34902d20756264486f9e309f22633d100261bd960feea6e858a098d985d", size = 12515659, upload-time = "2026-04-20T17:11:46.59Z" }, - { url = "https://files.pythonhosted.org/packages/40/0b/620c39cefa9ae1b25ee7a2da9d8d3c278b095649cb8435c5e01ea64f7c17/pyrefly-0.62.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4666c6b65aea662e5f77b64dc91c091b7ea5cede6aa66c0f4cbae26480403583", size = 36228332, upload-time = "2026-04-20T17:11:50.523Z" }, - { url = "https://files.pythonhosted.org/packages/2d/fb/47b8b76438c12761e509a3666cd5a99d4af7f21976ba8385feb475cbfe30/pyrefly-0.62.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1aefab798f47d37c13ded791192fee9b39a6d2b12e31f38ae06a1f80c4b26e22", size = 38995741, upload-time = "2026-04-20T17:11:54.702Z" }, - { url = "https://files.pythonhosted.org/packages/55/d2/03bd17673f61147cd5609cd7d6a1455eeccc17a07a7e141ed9931b0c42c0/pyrefly-0.62.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa986b50d56740da1d7ae7c660a505143cb9d286fa98cc7e5f4a759cc6eaa5d", size = 37205321, upload-time = "2026-04-20T17:11:58.9Z" }, - { url = "https://files.pythonhosted.org/packages/75/14/20ba7b7f2d182f9b7c1e24a3041dac9b5730ae28cfe1614a2c98706650f2/pyrefly-0.62.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:32e9b175805c82ffb967e4708f4910bace7e1a12736907380cc9afdbaabb0efb", size = 41786834, upload-time = "2026-04-20T17:12:03.221Z" }, - { url = "https://files.pythonhosted.org/packages/fa/c8/5a7ba88c4fa1b5090d877f70fa1b742b921b9e7d8d3f4b6b9b1ba1820850/pyrefly-0.62.0-py3-none-win32.whl", hash = "sha256:1cd98edc20cab5bac8016c9220ee66080e39bd22e7f0e9bb3e2c4e2be1555eed", size = 12010170, upload-time = "2026-04-20T17:12:06.791Z" }, - { url = "https://files.pythonhosted.org/packages/2e/78/d8f810de010ff2ed594c630c724fd817ef430963249e9eb396ce8f785e9d/pyrefly-0.62.0-py3-none-win_amd64.whl", hash = "sha256:6994f8ee7d6720325ee52207fbdaca98a799a1efe462bb5ba90c47160f7f3e6e", size = 12861816, upload-time = "2026-04-20T17:12:09.689Z" }, - { url = "https://files.pythonhosted.org/packages/c7/a9/ac824ef6a3f50b7c0ec5974471f8f2cb205cd1edd53a5abbcf7ba37feb5d/pyrefly-0.62.0-py3-none-win_arm64.whl", hash = "sha256:362a5d47a5ac5aaa5258091e878a1759ff8b687d8cf462af1c516144f7b0108a", size = 12352977, upload-time = "2026-04-20T17:12:12.736Z" }, + { url = "https://files.pythonhosted.org/packages/b8/1c/b001b7e84a811dbb3c85e31bd4bfc3edfa3c94438140cd1d6e8c06b7c1df/pyrefly-0.64.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:683b317d8d0e815fb2ad75b7e0fa6c15eed5be4bcbc407dc13312984da3a9c47", size = 13287462, upload-time = "2026-05-06T17:28:19.169Z" }, + { url = "https://files.pythonhosted.org/packages/89/02/1e6fcd311bd7c24aaccc0afb998d584e1fa6c370e1428b4b091103760efe/pyrefly-0.64.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:96913cc4f066a7bd008b9dba8e3951234e92bb8a3a2cb1aea0e274fd2a444c55", size = 12777104, upload-time = "2026-05-06T17:28:22.047Z" }, + { url = "https://files.pythonhosted.org/packages/d6/2b/3f347b8d97c9065d6ace14a22591c8d91e64610e74e0d4f214b3025ebcf7/pyrefly-0.64.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2ae557e1b6a6a5bda844806cae10b212cf84ea786ece10d55083a0321ee1705", size = 37064924, upload-time = "2026-05-06T17:28:24.743Z" }, + { url = "https://files.pythonhosted.org/packages/73/dd/0b40175e930a96139a8e9f62a8e1db7f9a5e9df8e6cef08bf280affcb05e/pyrefly-0.64.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d062ac1744346efacd7df23c6bbff662ad29ed495923cb59ede656a306355655", size = 39719832, upload-time = "2026-05-06T17:28:28.042Z" }, + { url = "https://files.pythonhosted.org/packages/9a/4b/0afb4ad02eb67ddb299ff3f7108ceb307e520578b00e900d07f2371423ca/pyrefly-0.64.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6850b305d45121911fbe25ad56497d2e887b387ea50644ba15a8ad2a8cf855f4", size = 37861666, upload-time = "2026-05-06T17:28:31.234Z" }, + { url = "https://files.pythonhosted.org/packages/e5/1b/f5390f8678433708288afab13f043ddd021a55dba3f665360d2c9396ee04/pyrefly-0.64.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a259925620a84fe87cd30a82643ec524eeef631f0c4ec5af81a21e006c2f5b1", size = 42634235, upload-time = "2026-05-06T17:28:34.405Z" }, + { url = "https://files.pythonhosted.org/packages/47/f7/4b66934e375dde3e4d75373b1a94eb7e7c0c0c788e94267641a223930180/pyrefly-0.64.0-py3-none-win32.whl", hash = "sha256:20317f6dd97e22bc508b8dbc537e59b0ab58e384113ee61920c87ed1a6a12f62", size = 12213388, upload-time = "2026-05-06T17:28:37.146Z" }, + { url = "https://files.pythonhosted.org/packages/0a/15/653523d99795041a1be6dadf7a73225317cb2aae4b21e6df57edbce807f0/pyrefly-0.64.0-py3-none-win_amd64.whl", hash = "sha256:e88fc6a83add9b7c2224be0f74df1b0db10b3af856ae30e4e0a90ba3644c712f", size = 13136719, upload-time = "2026-05-06T17:28:39.767Z" }, + { url = "https://files.pythonhosted.org/packages/50/bb/9ea1c26b511b38a3e1eefc1bd3de7d3f65b2bbfdb59295f3244f61564a81/pyrefly-0.64.0-py3-none-win_arm64.whl", hash = "sha256:73744bd95e836abda0d08e9cdcf008142090ae0124c8f8ff477c944b60c0343c", size = 12526050, upload-time = "2026-05-06T17:28:42.077Z" }, ] [[package]] diff --git a/eslint-suppressions.json b/eslint-suppressions.json index cd37f0ed89..b4876dcf45 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -155,9 +155,6 @@ } }, "web/app/account/(commonLayout)/account-page/email-change-modal.tsx": { - "erasable-syntax-only/enums": { - "count": 1 - }, "ts/no-explicit-any": { "count": 5 } @@ -1824,26 +1821,6 @@ "count": 1 } }, - "web/app/components/base/tag-management/__tests__/panel.spec.tsx": { - "ts/no-explicit-any": { - "count": 2 - } - }, - "web/app/components/base/tag-management/index.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/base/tag-management/tag-item-editor.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, - "web/app/components/base/tag-management/tag-remove-modal.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/base/text-generation/hooks.ts": { "ts/no-explicit-any": { "count": 1 @@ -1921,11 +1898,6 @@ "count": 4 } }, - "web/app/components/billing/plan/index.tsx": { - "ts/no-explicit-any": { - "count": 2 - } - }, "web/app/components/billing/pricing/assets/index.tsx": { "no-barrel-files/no-barrel-files": { "count": 12 @@ -2359,11 +2331,6 @@ "count": 1 } }, - "web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/components/datasets/metadata/edit-metadata-batch/input-combined.tsx": { "ts/no-explicit-any": { "count": 2 @@ -2469,17 +2436,6 @@ "count": 2 } }, - "web/app/components/explore/create-app-modal/index.tsx": { - "no-restricted-imports": { - "count": 1 - }, - "ts/no-explicit-any": { - "count": 1 - }, - "unicorn/prefer-number-properties": { - "count": 1 - } - }, "web/app/components/explore/item-operation/index.tsx": { "react/set-state-in-effect": { "count": 1 @@ -4238,11 +4194,6 @@ "count": 1 } }, - "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-value-method.tsx": { - "no-restricted-imports": { - "count": 1 - } - }, "web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/metadata-filter/index.tsx": { "no-restricted-imports": { "count": 1 @@ -5099,11 +5050,6 @@ "count": 5 } }, - "web/app/education-apply/verify-state-modal.tsx": { - "react/set-state-in-effect": { - "count": 1 - } - }, "web/app/forgot-password/ForgotPasswordForm.spec.tsx": { "ts/no-explicit-any": { "count": 5 @@ -5378,11 +5324,6 @@ "count": 2 } }, - "web/service/knowledge/use-dataset.ts": { - "@tanstack/query/exhaustive-deps": { - "count": 1 - } - }, "web/service/share.ts": { "erasable-syntax-only/enums": { "count": 1 diff --git a/packages/contracts/generated/api/console/apps/types.gen.ts b/packages/contracts/generated/api/console/apps/types.gen.ts index fe4c10329e..4a4742adcf 100644 --- a/packages/contracts/generated/api/console/apps/types.gen.ts +++ b/packages/contracts/generated/api/console/apps/types.gen.ts @@ -4156,8 +4156,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesResponse export type DeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdData = { body?: never path: { - app_id: string variable_id: string + app_id: string } query?: never url: '/apps/{app_id}/workflows/draft/variables/{variable_id}' @@ -4210,8 +4210,8 @@ export type GetAppsByAppIdWorkflowsDraftVariablesByVariableIdResponse export type PatchAppsByAppIdWorkflowsDraftVariablesByVariableIdData = { body: WorkflowDraftVariableUpdatePayload path: { - app_id: string variable_id: string + app_id: string } query?: never url: '/apps/{app_id}/workflows/draft/variables/{variable_id}' diff --git a/packages/contracts/generated/api/console/apps/zod.gen.ts b/packages/contracts/generated/api/console/apps/zod.gen.ts index dcaeaed246..9798d22cc0 100644 --- a/packages/contracts/generated/api/console/apps/zod.gen.ts +++ b/packages/contracts/generated/api/console/apps/zod.gen.ts @@ -2980,8 +2980,8 @@ export const zGetAppsByAppIdWorkflowsDraftVariablesQuery = z.object({ export const zGetAppsByAppIdWorkflowsDraftVariablesResponse = zWorkflowDraftVariableListWithoutValue export const zDeleteAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({ - app_id: z.string(), variable_id: z.string(), + app_id: z.string(), }) /** @@ -3006,8 +3006,8 @@ export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdBody = zWorkflowDraftVariableUpdatePayload export const zPatchAppsByAppIdWorkflowsDraftVariablesByVariableIdPath = z.object({ - app_id: z.string(), variable_id: z.string(), + app_id: z.string(), }) /** diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 89a68593b7..61d380d686 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -255,6 +255,7 @@ export type ProcessRule = { } export type RetrievalModel = { + metadata_filtering_conditions?: MetadataFilteringCondition reranking_enable: boolean reranking_mode?: string | null reranking_model?: RerankingModel @@ -312,6 +313,11 @@ export type Rule = { subchunk_segmentation?: Segmentation } +export type MetadataFilteringCondition = { + conditions?: Array | null + logical_operator?: 'and' | 'or' | null +} + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null @@ -405,6 +411,30 @@ export type Segmentation = { separator?: string } +export type Condition = { + comparison_operator: + | 'contains' + | 'not contains' + | 'start with' + | 'end with' + | 'is' + | 'is not' + | 'empty' + | 'not empty' + | 'in' + | 'not in' + | '=' + | '≠' + | '>' + | '<' + | '≥' + | '≤' + | 'before' + | 'after' + name: string + value?: unknown +} + export type WeightKeywordSetting = { keyword_weight: number } @@ -1174,8 +1204,8 @@ export type PatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse export type DeleteDatasetsByDatasetIdDocumentsByDocumentIdData = { body?: never path: { - dataset_id: string document_id: string + dataset_id: string } query?: never url: '/datasets/{dataset_id}/documents/{document_id}' diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index 2ac2cbfd1f..76491c52a0 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -392,6 +392,46 @@ export const zProcessRule = z.object({ rules: zRule.optional(), }) +/** + * Condition + * + * Condition detail + */ +export const zCondition = z.object({ + comparison_operator: z.enum([ + 'contains', + 'not contains', + 'start with', + 'end with', + 'is', + 'is not', + 'empty', + 'not empty', + 'in', + 'not in', + '=', + '≠', + '>', + '<', + '≥', + '≤', + 'before', + 'after', + ]), + name: z.string(), + value: z.unknown().optional(), +}) + +/** + * MetadataFilteringCondition + * + * Metadata Filtering Condition. + */ +export const zMetadataFilteringCondition = z.object({ + conditions: z.array(zCondition).nullish(), + logical_operator: z.enum(['and', 'or']).nullish().default('and'), +}) + /** * WeightKeywordSetting */ @@ -421,6 +461,7 @@ export const zWeightModel = z.object({ * RetrievalModel */ export const zRetrievalModel = z.object({ + metadata_filtering_conditions: zMetadataFilteringCondition.optional(), reranking_enable: z.boolean(), reranking_mode: z.string().nullish(), reranking_model: zRerankingModel.optional(), @@ -925,8 +966,8 @@ export const zPatchDatasetsByDatasetIdDocumentsStatusByActionBatchResponse = z.r ) export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdPath = z.object({ - dataset_id: z.string(), document_id: z.string(), + dataset_id: z.string(), }) /** diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index f491c1e3f9..e3791e295c 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -325,8 +325,37 @@ export type WorkflowRunResponse = { workflow_id: string } +export type Condition = { + comparison_operator: + | 'contains' + | 'not contains' + | 'start with' + | 'end with' + | 'is' + | 'is not' + | 'empty' + | 'not empty' + | 'in' + | 'not in' + | '=' + | '≠' + | '>' + | '<' + | '≥' + | '≤' + | 'before' + | 'after' + name: string + value?: unknown +} + export type DatasetPermissionEnum = 'only_me' | 'all_team_members' | 'partial_members' +export type MetadataFilteringCondition = { + conditions?: Array | null + logical_operator?: 'and' | 'or' | null +} + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null @@ -339,6 +368,7 @@ export type RetrievalMethod | 'keyword_search' export type RetrievalModel = { + metadata_filtering_conditions?: MetadataFilteringCondition reranking_enable: boolean reranking_mode?: string | null reranking_model?: RerankingModel @@ -1833,8 +1863,8 @@ export type GetDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdData = body?: never path: { segment_id: string - dataset_id: string document_id: string + dataset_id: string } query?: never url: '/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}' diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 2c2400c0cb..6feacbdead 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -326,11 +326,51 @@ export const zWorkflowRunResponse = z.object({ workflow_id: z.string(), }) +/** + * Condition + * + * Condition detail + */ +export const zCondition = z.object({ + comparison_operator: z.enum([ + 'contains', + 'not contains', + 'start with', + 'end with', + 'is', + 'is not', + 'empty', + 'not empty', + 'in', + 'not in', + '=', + '≠', + '>', + '<', + '≥', + '≤', + 'before', + 'after', + ]), + name: z.string(), + value: z.unknown().optional(), +}) + /** * DatasetPermissionEnum */ export const zDatasetPermissionEnum = z.enum(['only_me', 'all_team_members', 'partial_members']) +/** + * MetadataFilteringCondition + * + * Metadata Filtering Condition. + */ +export const zMetadataFilteringCondition = z.object({ + conditions: z.array(zCondition).nullish(), + logical_operator: z.enum(['and', 'or']).nullish().default('and'), +}) + /** * RerankingModel */ @@ -378,6 +418,7 @@ export const zWeightModel = z.object({ * RetrievalModel */ export const zRetrievalModel = z.object({ + metadata_filtering_conditions: zMetadataFilteringCondition.optional(), reranking_enable: z.boolean(), reranking_mode: z.string().nullish(), reranking_model: zRerankingModel.optional(), @@ -1082,8 +1123,8 @@ export const zDeleteDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdR export const zGetDatasetsByDatasetIdDocumentsByDocumentIdSegmentsBySegmentIdPath = z.object({ segment_id: z.string(), - dataset_id: z.string(), document_id: z.string(), + dataset_id: z.string(), }) /** diff --git a/packages/dev-proxy/README.md b/packages/dev-proxy/README.md new file mode 100644 index 0000000000..6b9d7298c4 --- /dev/null +++ b/packages/dev-proxy/README.md @@ -0,0 +1,196 @@ +# @langgenius/dev-proxy + +Generic Hono-based development proxy for frontend projects. The package does not ship any product-specific routes, cookie names, or environment variable conventions. Every proxied path and upstream target is declared in a local config file. + +## Installation + +```bash +pnpm add -D @langgenius/dev-proxy +``` + +Add a script in your frontend project: + +```json +{ + "scripts": { + "dev:proxy": "dev-proxy --config ./dev-proxy.config.ts --env-file ./.env" + } +} +``` + +Run it with: + +```bash +pnpm dev:proxy +``` + +## CLI + +```bash +dev-proxy --config ./dev-proxy.config.ts +``` + +Supported options: + +- `--config`, `-c`: config file path. Defaults to `dev-proxy.config.ts`. +- `--env-file`: load environment variables before evaluating the config file. +- `--host`: override `server.host` from config. +- `--port`: override `server.port` from config. +- `--help`, `-h`: print help. + +`--target` is not supported. Put targets in the config file so routes and upstreams stay explicit. + +## Config Shape + +```ts +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +export default defineDevProxyConfig({ + server: { + host: '127.0.0.1', + port: 5001, + }, + routes: [ + { + paths: '/api', + target: 'https://example.com', + }, + ], + cors: { + allowedOrigins: 'local', + }, +}) +``` + +Config files can be `.ts`, `.mts`, `.js`, or `.mjs`. + +`routes` are matched in declaration order. The first matching route wins. Each configured path matches both the exact path and all child paths, so `paths: '/api'` matches `/api`, `/api/apps`, and `/api/apps/123`. + +By default, credentialed CORS is allowed for local development origins such as `localhost`, `127.0.0.1`, and `::1`. To restrict it to specific origins: + +``` +cors: { + allowedOrigins: ['http://localhost:3000'], +} +``` + +## Scenario 1: Proxy One Local Route Group To An Online Backend + +Use this when a local frontend should call an online backend through one proxy server. For example, the frontend calls `http://127.0.0.1:5001/api/apps`, and the proxy forwards it to `https://cloud.example.com/api/apps`. + +```ts +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +const target = process.env.DEV_PROXY_TARGET || 'https://cloud.example.com' + +export default defineDevProxyConfig({ + server: { + host: process.env.DEV_PROXY_HOST || '127.0.0.1', + port: Number(process.env.DEV_PROXY_PORT || 5001), + }, + routes: [ + { + paths: '/api', + target, + }, + ], +}) +``` + +Optional `.env`: + +```env +DEV_PROXY_TARGET=https://cloud.example.com +DEV_PROXY_HOST=127.0.0.1 +DEV_PROXY_PORT=5001 +``` + +Command: + +```bash +dev-proxy --config ./dev-proxy.config.ts --env-file ./.env +``` + +## Scenario 2: Proxy Two Route Groups To Two Local Backends + +Use this when one frontend needs to talk to two different local services. For example: + +- `/console/api/*` goes to a local console backend at `http://127.0.0.1:5001` +- `/api/*` goes to a local public API backend at `http://127.0.0.1:5002` + +```ts +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +const consoleApiTarget = process.env.DEV_PROXY_CONSOLE_API_TARGET || 'http://127.0.0.1:5001' +const publicApiTarget = process.env.DEV_PROXY_PUBLIC_API_TARGET || 'http://127.0.0.1:5002' + +export default defineDevProxyConfig({ + server: { + host: process.env.DEV_PROXY_HOST || '127.0.0.1', + port: Number(process.env.DEV_PROXY_PORT || 8082), + }, + routes: [ + { + paths: '/console/api', + target: consoleApiTarget, + }, + { + paths: '/api', + target: publicApiTarget, + }, + ], +}) +``` + +Optional `.env`: + +```env +DEV_PROXY_CONSOLE_API_TARGET=http://127.0.0.1:5001 +DEV_PROXY_PUBLIC_API_TARGET=http://127.0.0.1:5002 +DEV_PROXY_HOST=127.0.0.1 +DEV_PROXY_PORT=8082 +``` + +When two route groups overlap, put the more specific one first: + +```ts +routes: [ + { paths: '/api/enterprise', target: 'http://127.0.0.1:5003' }, + { paths: '/api', target: 'http://127.0.0.1:5002' }, +] +``` + +## Cookie Rewrite + +Cookie rewriting is opt-in and config-driven. The package does not know any application cookie names. + +Use `cookieRewrite` when an upstream uses secure cookie prefixes such as `__Host-` or `__Secure-`, but local development needs cookies to work over `http://localhost`. + +```ts +import type { CookieRewriteOptions } from '@langgenius/dev-proxy' +import { defineDevProxyConfig } from '@langgenius/dev-proxy' + +const cookieRewrite: CookieRewriteOptions = { + hostPrefixCookies: ['access_token', 'refresh_token', /^passport-/], +} + +export default defineDevProxyConfig({ + routes: [ + { + paths: '/api', + target: 'https://cloud.example.com', + cookieRewrite, + }, + ], +}) +``` + +Set `cookieRewrite: false` to disable cookie rewriting for a route. + +## Behavior + +- The proxy preserves the matched path prefix when forwarding requests. +- Request bodies are forwarded as streams. +- Hop-by-hop headers are removed before forwarding. +- Local credentialed CORS and preflight requests are handled by the proxy. +- Route matching is explicit and order-sensitive. diff --git a/packages/dev-proxy/bin/dev-proxy.js b/packages/dev-proxy/bin/dev-proxy.js new file mode 100755 index 0000000000..02e37f3525 --- /dev/null +++ b/packages/dev-proxy/bin/dev-proxy.js @@ -0,0 +1,3 @@ +#!/usr/bin/env node + +import '../dist/cli.mjs' diff --git a/packages/dev-proxy/package.json b/packages/dev-proxy/package.json new file mode 100644 index 0000000000..d5524290eb --- /dev/null +++ b/packages/dev-proxy/package.json @@ -0,0 +1,43 @@ +{ + "name": "@langgenius/dev-proxy", + "type": "module", + "version": "0.0.5", + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs" + } + }, + "types": "./dist/index.d.mts", + "bin": { + "dev-proxy": "./bin/dev-proxy.js" + }, + "files": [ + "bin", + "dist", + "src" + ], + "engines": { + "node": "^22.22.1" + }, + "scripts": { + "build": "vp pack", + "prepare": "pnpm run build", + "test": "vp test", + "type-check": "tsgo", + "prepublish": "pnpm run build" + }, + "dependencies": { + "@hono/node-server": "catalog:", + "c12": "catalog:", + "hono": "catalog:" + }, + "devDependencies": { + "@dify/tsconfig": "workspace:*", + "@types/node": "catalog:", + "@typescript/native-preview": "catalog:", + "vite": "catalog:", + "vite-plus": "catalog:", + "vitest": "catalog:" + } +} diff --git a/packages/dev-proxy/src/cli.spec.ts b/packages/dev-proxy/src/cli.spec.ts new file mode 100644 index 0000000000..e8a87a0588 --- /dev/null +++ b/packages/dev-proxy/src/cli.spec.ts @@ -0,0 +1,158 @@ +/** + * @vitest-environment node + */ +import type { ChildProcessByStdio } from 'node:child_process' +import type { Readable } from 'node:stream' +import { spawn } from 'node:child_process' +import { once } from 'node:events' +import fs from 'node:fs/promises' +import net from 'node:net' +import os from 'node:os' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { afterEach, describe, expect, it } from 'vitest' + +const tempDirs: string[] = [] +type DevProxyCliProcess = ChildProcessByStdio + +const childProcesses: DevProxyCliProcess[] = [] +const binPath = fileURLToPath(new URL('../bin/dev-proxy.js', import.meta.url)) + +const createTempDir = async () => { + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'dev-proxy-cli-test-')) + tempDirs.push(tempDir) + return tempDir +} + +const getFreePort = async () => { + const server = net.createServer() + await new Promise((resolve, reject) => { + server.once('error', reject) + server.listen(0, '127.0.0.1', resolve) + }) + + const address = server.address() + if (!address || typeof address === 'string') + throw new Error('Failed to allocate a test port.') + + const { port } = address + await new Promise((resolve, reject) => { + server.close((error) => { + if (error) + reject(error) + else + resolve() + }) + }) + + return port +} + +const waitForOutput = ( + child: DevProxyCliProcess, + output: () => string, + expectedOutput: string, +) => new Promise((resolve, reject) => { + let timeout: ReturnType + + function cleanup() { + clearTimeout(timeout) + child.stdout.off('data', onData) + child.stderr.off('data', onData) + child.off('exit', onExit) + } + + function onData() { + if (!output().includes(expectedOutput)) + return + + cleanup() + resolve() + } + + function onExit(code: number | null, signal: NodeJS.Signals | null) { + cleanup() + reject(new Error(`dev-proxy exited before writing "${expectedOutput}" with code ${code} and signal ${signal}. Output:\n${output()}`)) + } + + timeout = setTimeout(() => { + cleanup() + reject(new Error(`Timed out waiting for "${expectedOutput}". Output:\n${output()}`)) + }, 3000) + + child.stdout.on('data', onData) + child.stderr.on('data', onData) + child.once('exit', onExit) + onData() +}) + +const spawnCli = (args: readonly string[], cwd: string) => { + const child = spawn(process.execPath, [binPath, ...args], { + cwd, + env: { + ...process.env, + FORCE_COLOR: '0', + }, + stdio: ['ignore', 'pipe', 'pipe'], + }) + childProcesses.push(child) + return child +} + +const stopChildProcess = async (child: DevProxyCliProcess) => { + if (child.exitCode !== null || child.signalCode !== null) + return + + child.kill('SIGTERM') + await once(child, 'exit') +} + +describe('dev proxy CLI', () => { + afterEach(async () => { + await Promise.all(childProcesses.splice(0).map(stopChildProcess)) + await Promise.all(tempDirs.splice(0).map(tempDir => fs.rm(tempDir, { + force: true, + recursive: true, + }))) + }) + + // Scenario: help output should still be a normal short-lived command. + it('should print help and exit', async () => { + // Arrange + const tempDir = await createTempDir() + const child = spawnCli(['--help'], tempDir) + + // Act + const [code] = await once(child, 'exit') + + // Assert + expect(code).toBe(0) + }) + + // Scenario: successful server startup should keep the CLI process alive. + it('should keep running after starting the proxy server', async () => { + // Arrange + const tempDir = await createTempDir() + const port = await getFreePort() + await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), ` + export default { + routes: [{ paths: '/api', target: 'https://api.example.com' }], + } + `) + + let output = '' + const child = spawnCli(['--config', './dev-proxy.config.ts', '--host', '127.0.0.1', '--port', String(port)], tempDir) + child.stdout.on('data', chunk => output += chunk.toString()) + child.stderr.on('data', chunk => output += chunk.toString()) + + // Act + await waitForOutput(child, () => output, `[dev-proxy] listening on http://127.0.0.1:${port}`) + await new Promise(resolve => setTimeout(resolve, 100)) + const response = await fetch(`http://127.0.0.1:${port}/not-proxied`) + + // Assert + expect(child.exitCode).toBeNull() + expect(child.signalCode).toBeNull() + expect(response.status).toBe(404) + }) +}) diff --git a/packages/dev-proxy/src/cli.ts b/packages/dev-proxy/src/cli.ts new file mode 100644 index 0000000000..05234cb359 --- /dev/null +++ b/packages/dev-proxy/src/cli.ts @@ -0,0 +1,56 @@ +import process from 'node:process' +import { serve } from '@hono/node-server' +import { loadDevProxyConfig, parseDevProxyCliArgs, resolveDevProxyServerOptions } from './config' +import { createDevProxyApp } from './server' + +function printUsage() { + console.log(`Usage: + dev-proxy --config [options] + +Options: + --config, -c Path to a dev proxy config file. Defaults to dev-proxy.config.ts. + --env-file Load environment variables before evaluating the config file. + --host Override the configured host. + --port Override the configured port. + --help, -h Show this help message.`) +} + +async function flushStandardStreams() { + await Promise.all([ + new Promise(resolve => process.stdout.write('', () => resolve())), + new Promise(resolve => process.stderr.write('', () => resolve())), + ]) +} + +async function main() { + const cliOptions = parseDevProxyCliArgs(process.argv.slice(2)) + + if (cliOptions.help) { + printUsage() + return + } + + const config = await loadDevProxyConfig(cliOptions.config, process.cwd(), { + envFile: cliOptions.envFile, + }) + const { host, port } = resolveDevProxyServerOptions(config.server, cliOptions) + const app = createDevProxyApp(config) + + serve({ + fetch: app.fetch, + hostname: host, + port, + }) + + console.log(`[dev-proxy] listening on http://${host}:${port}`) +} + +try { + await main() + await flushStandardStreams() +} +catch (error) { + console.error(error instanceof Error ? error.message : error) + await flushStandardStreams() + process.exit(1) +} diff --git a/packages/dev-proxy/src/config.spec.ts b/packages/dev-proxy/src/config.spec.ts new file mode 100644 index 0000000000..6f681bcbae --- /dev/null +++ b/packages/dev-proxy/src/config.spec.ts @@ -0,0 +1,145 @@ +/** + * @vitest-environment node + */ +import fs from 'node:fs/promises' +import os from 'node:os' +import path from 'node:path' +import { afterEach, describe, expect, it } from 'vitest' +import { loadDevProxyConfig, parseDevProxyCliArgs, resolveDevProxyServerOptions } from './config' + +const tempDirs: string[] = [] + +const createTempDir = async () => { + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'dev-proxy-test-')) + tempDirs.push(tempDir) + return tempDir +} + +describe('dev proxy config', () => { + afterEach(async () => { + delete process.env.DEV_PROXY_TEST_PORT + delete process.env.DEV_PROXY_TEST_TARGET + + await Promise.all(tempDirs.splice(0).map(tempDir => fs.rm(tempDir, { + force: true, + recursive: true, + }))) + }) + + // Scenario: CLI options should support both inline and separated values. + it('should parse proxy CLI options', () => { + // Act + const options = parseDevProxyCliArgs([ + '--config=./dev-proxy.config.ts', + '--env-file', + './.env.proxy', + '--host', + '0.0.0.0', + '--port', + '8083', + ]) + + // Assert + expect(options).toEqual({ + config: './dev-proxy.config.ts', + envFile: './.env.proxy', + host: '0.0.0.0', + port: '8083', + }) + }) + + // Scenario: removed target shortcuts should fail instead of silently doing the wrong thing. + it('should reject unsupported target shortcuts', () => { + // Assert + expect(() => parseDevProxyCliArgs(['--target', 'enterprise'])).toThrow('Unsupported dev proxy option') + }) + + // Scenario: package manager argument separators should not be treated as proxy options. + it('should ignore package manager argument separators', () => { + // Act + const options = parseDevProxyCliArgs(['--config', './dev-proxy.config.ts', '--', '--help']) + + // Assert + expect(options).toEqual({ + config: './dev-proxy.config.ts', + help: true, + }) + }) + + // Scenario: CLI host and port should override config defaults. + it('should resolve server options with CLI overrides', () => { + // Act + const options = resolveDevProxyServerOptions({ + host: '127.0.0.1', + port: 5001, + }, { + host: '0.0.0.0', + port: '9002', + }) + + // Assert + expect(options).toEqual({ + host: '0.0.0.0', + port: 9002, + }) + }) + + // Scenario: TS config files should load through c12. + it('should load a TypeScript config file', async () => { + // Arrange + const tempDir = await createTempDir() + await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), ` + export default { + server: { host: '127.0.0.1', port: 7777 }, + routes: [{ paths: ['/api', '/files'], target: 'https://api.example.com' }], + } + `) + + // Act + const config = await loadDevProxyConfig('dev-proxy.config.ts', tempDir) + + // Assert + expect(config.server).toEqual({ + host: '127.0.0.1', + port: 7777, + }) + expect(config.routes).toEqual([ + { + paths: ['/api', '/files'], + target: 'https://api.example.com', + }, + ]) + }) + + // Scenario: env files should be loaded before the TypeScript config is evaluated. + it('should load a TypeScript config file with env file values', async () => { + // Arrange + const tempDir = await createTempDir() + await fs.writeFile(path.join(tempDir, '.env.proxy'), [ + 'DEV_PROXY_TEST_PORT=7788', + 'DEV_PROXY_TEST_TARGET=https://env.example.com', + ].join('\n')) + await fs.writeFile(path.join(tempDir, 'dev-proxy.config.ts'), ` + export default { + server: { port: Number(process.env.DEV_PROXY_TEST_PORT) }, + routes: [{ paths: '/api', target: process.env.DEV_PROXY_TEST_TARGET }], + } + `) + + // Act + const config = await loadDevProxyConfig('dev-proxy.config.ts', tempDir, { + envFile: '.env.proxy', + }) + + // Assert + expect(config.server).toEqual({ + port: 7788, + }) + expect(config.routes).toEqual([ + { + paths: '/api', + target: 'https://env.example.com', + }, + ]) + }) +}) diff --git a/packages/dev-proxy/src/config.ts b/packages/dev-proxy/src/config.ts new file mode 100644 index 0000000000..b23cb0a152 --- /dev/null +++ b/packages/dev-proxy/src/config.ts @@ -0,0 +1,129 @@ +import type { DotenvOptions } from 'c12' +import type { DevProxyCliOptions, DevProxyConfig, DevProxyConfigLoadOptions, DevProxyServerConfig, ResolvedDevProxyServerOptions } from './types' +import path from 'node:path' +import { loadConfig } from 'c12' + +const DEFAULT_CONFIG_FILE = 'dev-proxy.config.ts' +const DEFAULT_PROXY_HOST = '127.0.0.1' +const DEFAULT_PROXY_PORT = 5001 + +const OPTION_NAME_TO_KEY = { + '--config': 'config', + '-c': 'config', + '--env-file': 'envFile', + '--host': 'host', + '--port': 'port', +} as const + +type OptionName = keyof typeof OPTION_NAME_TO_KEY + +const isOptionName = (value: string): value is OptionName => value in OPTION_NAME_TO_KEY + +const requireOptionValue = (name: string, value?: string) => { + if (!value || value.startsWith('-')) + throw new Error(`Missing value for ${name}.`) + + return value +} + +export const parseDevProxyCliArgs = (argv: readonly string[]): DevProxyCliOptions => { + const options: DevProxyCliOptions = {} + + for (let index = 0; index < argv.length; index += 1) { + const arg = argv[index]! + + if (arg === '--') + continue + + if (arg === '--help' || arg === '-h') { + options.help = true + continue + } + + const [rawName, inlineValue] = arg.split('=', 2) + const name = rawName ?? '' + + if (!name.startsWith('-')) + continue + + if (!isOptionName(name)) + throw new Error(`Unsupported dev proxy option "${name}".`) + + const key = OPTION_NAME_TO_KEY[name] + options[key] = inlineValue ?? requireOptionValue(name, argv[index + 1]) + + if (inlineValue === undefined) + index += 1 + } + + return options +} + +const resolvePort = (rawPort: string | number) => { + const port = Number(rawPort) + if (!Number.isInteger(port) || port < 1 || port > 65535) + throw new Error(`Invalid proxy port "${rawPort}". Expected an integer between 1 and 65535.`) + + return port +} + +export const resolveDevProxyServerOptions = ( + serverConfig: DevProxyServerConfig = {}, + cliOptions: DevProxyCliOptions = {}, +): ResolvedDevProxyServerOptions => { + const configuredPort = cliOptions.port ?? serverConfig.port ?? DEFAULT_PROXY_PORT + + return { + host: cliOptions.host || serverConfig.host || DEFAULT_PROXY_HOST, + port: resolvePort(configuredPort), + } +} + +const isRecord = (value: unknown): value is Record => + typeof value === 'object' && value !== null + +export function assertDevProxyConfig(config: unknown): asserts config is DevProxyConfig { + if (!isRecord(config)) + throw new Error('Dev proxy config must export an object.') + + if (!Array.isArray(config.routes)) + throw new Error('Dev proxy config must include a routes array.') +} + +const resolveDotenvOptions = ( + envFile: DevProxyConfigLoadOptions['envFile'], + cwd: string, +): DotenvOptions | false => { + if (!envFile) + return false + + const resolvedEnvFilePath = path.resolve(cwd, envFile) + return { + cwd: path.dirname(resolvedEnvFilePath), + fileName: path.basename(resolvedEnvFilePath), + interpolate: true, + } +} + +export const loadDevProxyConfig = async ( + configPath = DEFAULT_CONFIG_FILE, + cwd = process.cwd(), + options: DevProxyConfigLoadOptions = {}, +): Promise => { + const resolvedConfigPath = path.resolve(cwd, configPath) + const parsedPath = path.parse(resolvedConfigPath) + const { config: loadedConfig } = await loadConfig({ + configFile: parsedPath.name, + cwd: parsedPath.dir, + dotenv: resolveDotenvOptions(options.envFile, cwd), + envName: false, + globalRc: false, + packageJson: false, + rcFile: false, + }) + + assertDevProxyConfig(loadedConfig) + return loadedConfig +} + +export const defineDevProxyConfig = (config: DevProxyConfig) => config diff --git a/packages/dev-proxy/src/cookies.spec.ts b/packages/dev-proxy/src/cookies.spec.ts new file mode 100644 index 0000000000..4a1b614eeb --- /dev/null +++ b/packages/dev-proxy/src/cookies.spec.ts @@ -0,0 +1,44 @@ +/** + * @vitest-environment node + */ +import { describe, expect, it } from 'vitest' +import { rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies' + +describe('dev proxy cookies', () => { + // Scenario: cookie names should only receive secure host prefixes when configured. + it('should rewrite configured cookie names for HTTPS upstream requests', () => { + // Act + const cookieHeader = rewriteCookieHeaderForUpstream('access_token=abc; theme=dark; passport-app=def', { + hostPrefixCookies: ['access_token', /^passport-/], + useHostPrefix: true, + }) + + // Assert + expect(cookieHeader).toBe('__Host-access_token=abc; theme=dark; __Host-passport-app=def') + }) + + // Scenario: HTTP upstreams should keep local cookie names even when rewrite config exists. + it('should keep local cookie names for HTTP upstream requests', () => { + // Act + const cookieHeader = rewriteCookieHeaderForUpstream('access_token=abc; refresh_token=def', { + hostPrefixCookies: ['access_token', 'refresh_token'], + useHostPrefix: false, + }) + + // Assert + expect(cookieHeader).toBe('access_token=abc; refresh_token=def') + }) + + // Scenario: upstream set-cookie headers should be converted into localhost-safe cookies. + it('should rewrite upstream set-cookie headers for local development', () => { + // Act + const cookies = rewriteSetCookieHeadersForLocal([ + '__Host-access_token=abc; Path=/console/api; Domain=cloud.example.com; Secure; SameSite=None; Partitioned', + ]) + + // Assert + expect(cookies).toEqual([ + 'access_token=abc; Path=/; SameSite=Lax', + ]) + }) +}) diff --git a/web/plugins/dev-proxy/cookies.ts b/packages/dev-proxy/src/cookies.ts similarity index 61% rename from web/plugins/dev-proxy/cookies.ts rename to packages/dev-proxy/src/cookies.ts index ad087d1549..61fdb6abd4 100644 --- a/web/plugins/dev-proxy/cookies.ts +++ b/packages/dev-proxy/src/cookies.ts @@ -1,4 +1,4 @@ -const DEFAULT_PROXY_TARGET = 'https://cloud.dify.ai' +import type { CookieRewriteOptions } from './types' const SECURE_COOKIE_PREFIX_PATTERN = /^__(Host|Secure)-/ const SAME_SITE_NONE_PATTERN = /^samesite=none$/i @@ -7,38 +7,37 @@ const COOKIE_DOMAIN_PATTERN = /^domain=/i const COOKIE_SECURE_PATTERN = /^secure$/i const COOKIE_PARTITIONED_PATTERN = /^partitioned$/i -const HOST_PREFIX_COOKIE_NAMES = new Set([ - 'access_token', - 'csrf_token', - 'refresh_token', - 'webapp_access_token', -]) +const stripSecureCookiePrefix = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '') -const isPassportCookie = (cookieName: string) => cookieName.startsWith('passport-') +const matchesCookieName = (cookieName: string, matcher: string | RegExp) => + typeof matcher === 'string' + ? matcher === cookieName + : matcher.test(cookieName) -const shouldUseHostPrefix = (cookieName: string) => { - const normalizedCookieName = cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '') - return HOST_PREFIX_COOKIE_NAMES.has(normalizedCookieName) || isPassportCookie(normalizedCookieName) +const shouldUseHostPrefix = (cookieName: string, options: CookieRewriteOptions) => { + const normalizedCookieName = stripSecureCookiePrefix(cookieName) + + return options.hostPrefixCookies?.some(matcher => matchesCookieName(normalizedCookieName, matcher)) || false } -const toUpstreamCookieName = (cookieName: string) => { +const toUpstreamCookieName = (cookieName: string, options: CookieRewriteOptions) => { if (cookieName.startsWith('__Host-')) return cookieName if (cookieName.startsWith('__Secure-')) - return `__Host-${cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')}` + return `__Host-${stripSecureCookiePrefix(cookieName)}` - if (!shouldUseHostPrefix(cookieName)) + if (!shouldUseHostPrefix(cookieName, options)) return cookieName return `__Host-${cookieName}` } -const toLocalCookieName = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '') +export const toLocalCookieName = (cookieName: string) => stripSecureCookiePrefix(cookieName) export const rewriteCookieHeaderForUpstream = ( - cookieHeader?: string, - options: { useHostPrefix?: boolean } = {}, + cookieHeader: string | undefined, + options: CookieRewriteOptions & { useHostPrefix?: boolean }, ) => { if (!cookieHeader) return cookieHeader @@ -55,7 +54,11 @@ export const rewriteCookieHeaderForUpstream = ( const cookieName = cookie.slice(0, separatorIndex).trim() const cookieValue = cookie.slice(separatorIndex + 1) - return `${useHostPrefix ? toUpstreamCookieName(cookieName) : cookieName}=${cookieValue}` + const upstreamCookieName = useHostPrefix + ? toUpstreamCookieName(cookieName, options) + : cookieName + + return `${upstreamCookieName}=${cookieValue}` }) .join('; ') } @@ -89,15 +92,5 @@ const rewriteSetCookieValueForLocal = (setCookieValue: string) => { return [`${toLocalCookieName(cookieName)}=${cookieValue}`, ...rewrittenAttributes].join('; ') } -export const rewriteSetCookieHeadersForLocal = (setCookieHeaders?: string | string[]): string[] | undefined => { - if (!setCookieHeaders) - return undefined - - const normalizedHeaders = Array.isArray(setCookieHeaders) - ? setCookieHeaders - : [setCookieHeaders] - - return normalizedHeaders.map(rewriteSetCookieValueForLocal) -} - -export { DEFAULT_PROXY_TARGET } +export const rewriteSetCookieHeadersForLocal = (setCookieHeaders: readonly string[]) => + setCookieHeaders.map(rewriteSetCookieValueForLocal) diff --git a/packages/dev-proxy/src/index.ts b/packages/dev-proxy/src/index.ts new file mode 100644 index 0000000000..e35893b98f --- /dev/null +++ b/packages/dev-proxy/src/index.ts @@ -0,0 +1,22 @@ +export { + assertDevProxyConfig, + defineDevProxyConfig, + loadDevProxyConfig, + parseDevProxyCliArgs, + resolveDevProxyServerOptions, +} from './config' +export { rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal, toLocalCookieName } from './cookies' +export { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, isAllowedLocalDevOrigin } from './server' +export type { + CookieNameMatcher, + CookieRewriteOptions, + CreateDevProxyAppOptions, + DevProxyCliOptions, + DevProxyConfig, + DevProxyConfigLoadOptions, + DevProxyCorsAllowedOrigins, + DevProxyCorsConfig, + DevProxyRoute, + DevProxyServerConfig, + ResolvedDevProxyServerOptions, +} from './types' diff --git a/web/plugins/dev-proxy/server.spec.ts b/packages/dev-proxy/src/server.spec.ts similarity index 54% rename from web/plugins/dev-proxy/server.spec.ts rename to packages/dev-proxy/src/server.spec.ts index 4b3344be42..32c16a1807 100644 --- a/web/plugins/dev-proxy/server.spec.ts +++ b/packages/dev-proxy/src/server.spec.ts @@ -2,41 +2,13 @@ * @vitest-environment node */ import { beforeEach, describe, expect, it, vi } from 'vitest' -import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server' +import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin } from './server' describe('dev proxy server', () => { beforeEach(() => { vi.clearAllMocks() }) - // Scenario: Hono proxy targets should be read directly from env. - it('should resolve Hono proxy targets from env', () => { - // Arrange - const targets = resolveDevProxyTargets({ - HONO_CONSOLE_API_PROXY_TARGET: 'https://console.example.com', - HONO_PUBLIC_API_PROXY_TARGET: 'https://public.example.com', - HONO_ENTERPRISE_API_PROXY_TARGET: 'https://enterprise.example.com', - }) - - // Assert - expect(targets.consoleApiTarget).toBe('https://console.example.com') - expect(targets.publicApiTarget).toBe('https://public.example.com') - expect(targets.enterpriseApiTarget).toBe('https://enterprise.example.com') - }) - - // Scenario: optional proxy targets should use their route-specific defaults. - it('should use console target as the default for optional targets', () => { - // Act - const targets = resolveDevProxyTargets({ - HONO_CONSOLE_API_PROXY_TARGET: 'https://console.example.com', - }) - - // Assert - expect(targets.consoleApiTarget).toBe('https://console.example.com') - expect(targets.publicApiTarget).toBe('https://console.example.com') - expect(targets.enterpriseApiTarget).toBeUndefined() - }) - // Scenario: target paths should not be duplicated when the incoming route already includes them. it('should preserve prefixed targets when building upstream URLs', () => { // Act @@ -46,30 +18,43 @@ describe('dev proxy server', () => { expect(url.href).toBe('https://api.example.com/console/api/apps?page=1') }) - // Scenario: only localhost dev origins should be reflected for credentialed CORS. - it('should only allow local development origins', () => { + // Scenario: only localhost dev origins should be reflected for credentialed CORS by default. + it('should only allow local development origins by default', () => { // Assert expect(isAllowedDevOrigin('http://localhost:3000')).toBe(true) expect(isAllowedDevOrigin('http://127.0.0.1:3000')).toBe(true) expect(isAllowedDevOrigin('https://example.com')).toBe(false) }) - // Scenario: proxy requests should rewrite cookies and surface credentialed CORS headers. - it('should proxy api requests through Hono with local cookie rewriting', async () => { + // Scenario: explicit CORS origins should support non-local development hosts. + it('should allow explicitly configured origins', () => { + // Assert + expect(isAllowedDevOrigin('https://app.example.com', ['https://app.example.com'])).toBe(true) + expect(isAllowedDevOrigin('https://other.example.com', ['https://app.example.com'])).toBe(false) + }) + + // Scenario: proxy requests should rewrite cookies and surface credentialed CORS headers when configured. + it('should proxy api requests with configured local cookie rewriting', async () => { // Arrange const fetchImpl = vi.fn().mockResolvedValue(new Response('ok', { status: 200, headers: [ ['content-encoding', 'br'], ['content-length', '123'], - ['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None'], + ['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.example.com; Secure; SameSite=None'], ['transfer-encoding', 'chunked'], ], })) const app = createDevProxyApp({ - consoleApiTarget: 'https://cloud.dify.ai', - publicApiTarget: 'https://public.dify.ai', - enterpriseApiTarget: 'https://enterprise.dify.ai', + routes: [ + { + paths: '/console/api', + target: 'https://cloud.example.com', + cookieRewrite: { + hostPrefixCookies: ['access_token'], + }, + }, + ], fetchImpl, }) @@ -77,7 +62,7 @@ describe('dev proxy server', () => { const response = await app.request('http://127.0.0.1:5001/console/api/apps?page=1', { headers: { 'Origin': 'http://localhost:3000', - 'Cookie': 'access_token=abc', + 'Cookie': 'access_token=abc; theme=dark', 'Accept-Encoding': 'zstd, br, gzip', }, }) @@ -85,7 +70,7 @@ describe('dev proxy server', () => { // Assert expect(fetchImpl).toHaveBeenCalledTimes(1) expect(fetchImpl).toHaveBeenCalledWith( - new URL('https://cloud.dify.ai/console/api/apps?page=1'), + new URL('https://cloud.example.com/console/api/apps?page=1'), expect.objectContaining({ method: 'GET', headers: expect.any(Headers), @@ -96,8 +81,8 @@ describe('dev proxy server', () => { if (!(requestHeaders instanceof Headers)) throw new Error('Expected proxy request headers to be Headers') - expect(requestHeaders.get('cookie')).toBe('__Host-access_token=abc') - expect(requestHeaders.get('origin')).toBe('https://cloud.dify.ai') + expect(requestHeaders.get('cookie')).toBe('__Host-access_token=abc; theme=dark') + expect(requestHeaders.get('origin')).toBe('https://cloud.example.com') expect(requestHeaders.get('accept-encoding')).toBe('identity') expect(response.headers.get('access-control-allow-origin')).toBe('http://localhost:3000') expect(response.headers.get('access-control-allow-credentials')).toBe('true') @@ -109,14 +94,49 @@ describe('dev proxy server', () => { ]) }) - // Scenario: a local HTTP Dify API expects the non-prefixed local cookie name. + // Scenario: generic proxy routes should not know Dify cookie names by default. + it('should not rewrite cookie names when cookie rewriting is not configured', async () => { + // Arrange + const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) + const app = createDevProxyApp({ + routes: [ + { + paths: '/api', + target: 'https://api.example.com', + }, + ], + fetchImpl, + }) + + // Act + await app.request('http://127.0.0.1:5001/api/messages', { + headers: { + Cookie: 'access_token=abc; refresh_token=def', + }, + }) + + // Assert + const requestHeaders = fetchImpl.mock.calls[0]?.[1]?.headers + if (!(requestHeaders instanceof Headers)) + throw new Error('Expected proxy request headers to be Headers') + + expect(requestHeaders.get('cookie')).toBe('access_token=abc; refresh_token=def') + }) + + // Scenario: local HTTP upstreams expect local cookie names even when cookie rewriting is configured. it('should keep local cookie names for HTTP upstream targets', async () => { // Arrange const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) const app = createDevProxyApp({ - consoleApiTarget: 'http://127.0.0.1:5001', - publicApiTarget: 'http://127.0.0.1:5001', - enterpriseApiTarget: 'http://127.0.0.1:8082', + routes: [ + { + paths: '/console/api', + target: 'http://127.0.0.1:5001', + cookieRewrite: { + hostPrefixCookies: ['access_token', 'refresh_token'], + }, + }, + ], fetchImpl, }) @@ -135,47 +155,59 @@ describe('dev proxy server', () => { expect(requestHeaders.get('cookie')).toBe('access_token=abc; refresh_token=def') }) - // Scenario: Enterprise dashboard routes should use the Enterprise target before generic API routes. - it('should proxy enterprise api routes to the enterprise target', async () => { + // Scenario: custom route paths should support independent upstream targets. + it('should proxy custom route paths to their configured targets', async () => { // Arrange const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) const app = createDevProxyApp({ - consoleApiTarget: 'https://console.example.com', - publicApiTarget: 'https://public.example.com', - enterpriseApiTarget: 'https://enterprise.example.com', + routes: [ + { + paths: '/api', + target: 'https://api.example.com', + }, + { + paths: '/files', + target: 'https://files.example.com/assets', + }, + ], fetchImpl, }) - const requestUrls = [ - 'http://127.0.0.1:5001/console/api/enterprise/sso/saml/login', - 'http://127.0.0.1:5001/api/enterprise/sso/oauth2/login', - 'http://127.0.0.1:5001/admin-api/v1/workspaces', - 'http://127.0.0.1:5001/inner/api/info', - 'http://127.0.0.1:5001/mfa/v1/verify', - 'http://127.0.0.1:5001/scim/v2/Users', - 'http://127.0.0.1:5001/v1/audit/logs', - 'http://127.0.0.1:5001/v1/dashboard/api/license/status', - 'http://127.0.0.1:5001/v1/healthz', - 'http://127.0.0.1:5001/v1/plugin-manager/plugins', - ] - // Act - for (const url of requestUrls) - await app.request(url) + await app.request('http://127.0.0.1:5001/api/messages') + await app.request('http://127.0.0.1:5001/files/logo.png?size=small') // Assert - expect(fetchImpl).toHaveBeenCalledTimes(requestUrls.length) expect(fetchImpl.mock.calls.map(([url]) => url.toString())).toEqual([ - 'https://enterprise.example.com/console/api/enterprise/sso/saml/login', - 'https://enterprise.example.com/api/enterprise/sso/oauth2/login', - 'https://enterprise.example.com/admin-api/v1/workspaces', - 'https://enterprise.example.com/inner/api/info', - 'https://enterprise.example.com/mfa/v1/verify', - 'https://enterprise.example.com/scim/v2/Users', - 'https://enterprise.example.com/v1/audit/logs', - 'https://enterprise.example.com/v1/dashboard/api/license/status', - 'https://enterprise.example.com/v1/healthz', - 'https://enterprise.example.com/v1/plugin-manager/plugins', + 'https://api.example.com/api/messages', + 'https://files.example.com/assets/files/logo.png?size=small', + ]) + }) + + // Scenario: routes are matched in config order so callers can put specific routes first. + it('should prefer earlier route entries', async () => { + // Arrange + const fetchImpl = vi.fn().mockResolvedValue(new Response('ok')) + const app = createDevProxyApp({ + routes: [ + { + paths: '/api/enterprise', + target: 'https://enterprise.example.com', + }, + { + paths: '/api', + target: 'https://api.example.com', + }, + ], + fetchImpl, + }) + + // Act + await app.request('http://127.0.0.1:5001/api/enterprise/sso/login') + + // Assert + expect(fetchImpl.mock.calls.map(([url]) => url.toString())).toEqual([ + 'https://enterprise.example.com/api/enterprise/sso/login', ]) }) @@ -183,9 +215,12 @@ describe('dev proxy server', () => { it('should answer CORS preflight requests', async () => { // Arrange const app = createDevProxyApp({ - consoleApiTarget: 'https://cloud.dify.ai', - publicApiTarget: 'https://public.dify.ai', - enterpriseApiTarget: 'https://enterprise.dify.ai', + routes: [ + { + paths: '/api', + target: 'https://api.example.com', + }, + ], fetchImpl: vi.fn(), }) diff --git a/web/plugins/dev-proxy/server.ts b/packages/dev-proxy/src/server.ts similarity index 52% rename from web/plugins/dev-proxy/server.ts rename to packages/dev-proxy/src/server.ts index e4867b6077..79654750da 100644 --- a/web/plugins/dev-proxy/server.ts +++ b/packages/dev-proxy/src/server.ts @@ -1,25 +1,9 @@ import type { Context, Hono } from 'hono' +import type { CookieRewriteOptions, CreateDevProxyAppOptions, DevProxyCorsAllowedOrigins, DevProxyRoute } from './types' import { Hono as HonoApp } from 'hono' -import { DEFAULT_PROXY_TARGET, rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies' +import { rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies' -type DevProxyEnv = Partial> - -type DevProxyTargets = { - consoleApiTarget: string - publicApiTarget: string - enterpriseApiTarget?: string -} - -type DevProxyAppOptions = DevProxyTargets & { - fetchImpl?: typeof globalThis.fetch -} - -const LOCAL_DEV_HOSTS = new Set(['localhost', '127.0.0.1', '[::1]']) +const LOCAL_DEV_HOSTS = new Set(['localhost', '127.0.0.1', '[::1]', '::1']) const ALLOW_METHODS = 'GET,HEAD,POST,PUT,PATCH,DELETE,OPTIONS' const DEFAULT_ALLOW_HEADERS = 'Authorization, Content-Type, X-CSRF-Token' const UPSTREAM_ACCEPT_ENCODING = 'identity' @@ -28,31 +12,14 @@ const RESPONSE_HEADERS_TO_DROP = [ 'content-encoding', 'content-length', 'keep-alive', - 'set-cookie', + 'proxy-authenticate', + 'proxy-authorization', + 'te', + 'trailer', 'transfer-encoding', + 'upgrade', ] as const -const ENTERPRISE_API_ROUTES = [ - '/console/api/enterprise', - '/api/enterprise', - '/admin-api', - '/inner/api', - '/mfa', - '/scim', - '/v1/audit', - '/v1/dashboard', - '/v1/healthz', - '/v1/plugin-manager', -] as const - -const CONSOLE_API_ROUTES = ['/console/api'] as const -const PUBLIC_API_ROUTES = ['/api'] as const - -type ProxyRoutePath - = | typeof ENTERPRISE_API_ROUTES[number] - | typeof CONSOLE_API_ROUTES[number] - | typeof PUBLIC_API_ROUTES[number] - const appendHeaderValue = (headers: Headers, name: string, value: string) => { const currentValue = headers.get(name) if (!currentValue) { @@ -66,7 +33,7 @@ const appendHeaderValue = (headers: Headers, name: string, value: string) => { headers.set(name, `${currentValue}, ${value}`) } -export const isAllowedDevOrigin = (origin?: string | null) => { +export const isAllowedLocalDevOrigin = (origin?: string | null) => { if (!origin) return false @@ -79,8 +46,25 @@ export const isAllowedDevOrigin = (origin?: string | null) => { } } -const applyCorsHeaders = (headers: Headers, origin?: string | null) => { - if (!isAllowedDevOrigin(origin)) +export const isAllowedDevOrigin = ( + origin?: string | null, + allowedOrigins: DevProxyCorsAllowedOrigins = 'local', +) => { + if (!origin) + return false + + if (allowedOrigins === 'local') + return isAllowedLocalDevOrigin(origin) + + return allowedOrigins.includes(origin) +} + +const applyCorsHeaders = ( + headers: Headers, + origin: string | undefined | null, + allowedOrigins: DevProxyCorsAllowedOrigins = 'local', +) => { + if (!isAllowedDevOrigin(origin, allowedOrigins)) return headers.set('Access-Control-Allow-Origin', origin!) @@ -103,7 +87,11 @@ export const buildUpstreamUrl = (target: string, requestPath: string, search = ' return targetUrl } -const createProxyRequestHeaders = (request: Request, targetUrl: URL) => { +const createProxyRequestHeaders = ( + request: Request, + targetUrl: URL, + cookieRewrite: CookieRewriteOptions | false | undefined, +) => { const headers = new Headers(request.headers) headers.delete('host') headers.set('accept-encoding', UPSTREAM_ACCEPT_ENCODING) @@ -111,36 +99,60 @@ const createProxyRequestHeaders = (request: Request, targetUrl: URL) => { if (headers.has('origin')) headers.set('origin', targetUrl.origin) - const rewrittenCookieHeader = rewriteCookieHeaderForUpstream(headers.get('cookie') || undefined, { - useHostPrefix: targetUrl.protocol === 'https:', - }) - if (rewrittenCookieHeader) - headers.set('cookie', rewrittenCookieHeader) + if (cookieRewrite) { + const rewrittenCookieHeader = rewriteCookieHeaderForUpstream(headers.get('cookie') || undefined, { + ...cookieRewrite, + useHostPrefix: targetUrl.protocol === 'https:', + }) + if (rewrittenCookieHeader) + headers.set('cookie', rewrittenCookieHeader) + } return headers } -const createUpstreamResponseHeaders = (response: Response, requestOrigin?: string | null) => { +const getSetCookieHeaders = (headers: Headers) => { + const headersWithGetSetCookie = headers as Headers & { getSetCookie?: () => string[] } + const setCookieHeaders = headersWithGetSetCookie.getSetCookie?.() + if (setCookieHeaders?.length) + return setCookieHeaders + + const setCookie = headers.get('set-cookie') + return setCookie ? [setCookie] : [] +} + +const createUpstreamResponseHeaders = ( + response: Response, + requestOrigin: string | undefined | null, + allowedOrigins: DevProxyCorsAllowedOrigins, + cookieRewrite: CookieRewriteOptions | false | undefined, +) => { const headers = new Headers(response.headers) RESPONSE_HEADERS_TO_DROP.forEach(header => headers.delete(header)) + headers.delete('set-cookie') - const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie()) - rewrittenSetCookies?.forEach((cookie) => { + const setCookieHeaders = getSetCookieHeaders(response.headers) + const responseSetCookieHeaders = cookieRewrite + ? rewriteSetCookieHeadersForLocal(setCookieHeaders) + : setCookieHeaders + + responseSetCookieHeaders.forEach((cookie) => { headers.append('set-cookie', cookie) }) - applyCorsHeaders(headers, requestOrigin) + applyCorsHeaders(headers, requestOrigin, allowedOrigins) return headers } const proxyRequest = async ( context: Context, - target: string, + route: DevProxyRoute, fetchImpl: typeof globalThis.fetch, + allowedOrigins: DevProxyCorsAllowedOrigins, ) => { const requestUrl = new URL(context.req.url) - const targetUrl = buildUpstreamUrl(target, requestUrl.pathname, requestUrl.search) - const requestHeaders = createProxyRequestHeaders(context.req.raw, targetUrl) + const targetUrl = buildUpstreamUrl(route.target, requestUrl.pathname, requestUrl.search) + const requestHeaders = createProxyRequestHeaders(context.req.raw, targetUrl, route.cookieRewrite) const requestInit: RequestInit & { duplex?: 'half' } = { method: context.req.method, headers: requestHeaders, @@ -153,7 +165,12 @@ const proxyRequest = async ( } const upstreamResponse = await fetchImpl(targetUrl, requestInit) - const responseHeaders = createUpstreamResponseHeaders(upstreamResponse, context.req.header('origin')) + const responseHeaders = createUpstreamResponseHeaders( + upstreamResponse, + context.req.header('origin'), + allowedOrigins, + route.cookieRewrite, + ) return new Response(upstreamResponse.body, { status: upstreamResponse.status, @@ -162,48 +179,46 @@ const proxyRequest = async ( }) } +const normalizeRoutePaths = (paths: DevProxyRoute['paths']) => Array.isArray(paths) ? paths : [paths] + const registerProxyRoute = ( app: Hono, - path: ProxyRoutePath, - target: string, + route: DevProxyRoute, + path: string, fetchImpl: typeof globalThis.fetch, + allowedOrigins: DevProxyCorsAllowedOrigins, ) => { - app.all(path, context => proxyRequest(context, target, fetchImpl)) - app.all(`${path}/*`, context => proxyRequest(context, target, fetchImpl)) + if (!path.startsWith('/')) + throw new Error(`Invalid dev proxy route path "${path}". Paths must start with "/".`) + + app.all(path, context => proxyRequest(context, route, fetchImpl, allowedOrigins)) + app.all(`${path}/*`, context => proxyRequest(context, route, fetchImpl, allowedOrigins)) } const registerProxyRoutes = ( app: Hono, - routes: readonly ProxyRoutePath[], - target: string, + routes: readonly DevProxyRoute[], fetchImpl: typeof globalThis.fetch, + allowedOrigins: DevProxyCorsAllowedOrigins, ) => { - routes.forEach(route => registerProxyRoute(app, route, target, fetchImpl)) + routes.forEach((route) => { + normalizeRoutePaths(route.paths).forEach((path) => { + registerProxyRoute(app, route, path, fetchImpl, allowedOrigins) + }) + }) } -export const resolveDevProxyTargets = (env: DevProxyEnv = {}): DevProxyTargets => { - const consoleApiTarget = env.HONO_CONSOLE_API_PROXY_TARGET - || DEFAULT_PROXY_TARGET - const publicApiTarget = env.HONO_PUBLIC_API_PROXY_TARGET - || consoleApiTarget - const enterpriseApiTarget = env.HONO_ENTERPRISE_API_PROXY_TARGET - - return { - consoleApiTarget, - publicApiTarget, - enterpriseApiTarget, - } -} - -export const createDevProxyApp = (options: DevProxyAppOptions) => { +export const createDevProxyApp = (options: CreateDevProxyAppOptions) => { const app = new HonoApp() const fetchImpl = options.fetchImpl || globalThis.fetch + const logger = options.logger || console + const allowedOrigins = options.cors?.allowedOrigins || 'local' app.onError((error, context) => { - console.error('[dev-hono-proxy]', error) + logger.error('[dev-proxy]', error) const headers = new Headers() - applyCorsHeaders(headers, context.req.header('origin')) + applyCorsHeaders(headers, context.req.header('origin'), allowedOrigins) return new Response('Upstream proxy request failed.', { status: 502, @@ -214,7 +229,7 @@ export const createDevProxyApp = (options: DevProxyAppOptions) => { app.use('*', async (context, next) => { if (context.req.method === 'OPTIONS') { const headers = new Headers() - applyCorsHeaders(headers, context.req.header('origin')) + applyCorsHeaders(headers, context.req.header('origin'), allowedOrigins) headers.set('Access-Control-Allow-Methods', ALLOW_METHODS) headers.set( 'Access-Control-Allow-Headers', @@ -230,13 +245,10 @@ export const createDevProxyApp = (options: DevProxyAppOptions) => { } await next() - applyCorsHeaders(context.res.headers, context.req.header('origin')) + applyCorsHeaders(context.res.headers, context.req.header('origin'), allowedOrigins) }) - if (options.enterpriseApiTarget) - registerProxyRoutes(app, ENTERPRISE_API_ROUTES, options.enterpriseApiTarget, fetchImpl) - registerProxyRoutes(app, CONSOLE_API_ROUTES, options.consoleApiTarget, fetchImpl) - registerProxyRoutes(app, PUBLIC_API_ROUTES, options.publicApiTarget, fetchImpl) + registerProxyRoutes(app, options.routes, fetchImpl, allowedOrigins) return app } diff --git a/packages/dev-proxy/src/types.ts b/packages/dev-proxy/src/types.ts new file mode 100644 index 0000000000..2c42b2f7fb --- /dev/null +++ b/packages/dev-proxy/src/types.ts @@ -0,0 +1,50 @@ +export type DevProxyServerConfig = { + host?: string + port?: number +} + +export type DevProxyCorsAllowedOrigins = 'local' | readonly string[] + +export type DevProxyCorsConfig = { + allowedOrigins?: DevProxyCorsAllowedOrigins +} + +export type CookieNameMatcher = string | RegExp + +export type CookieRewriteOptions = { + hostPrefixCookies?: readonly CookieNameMatcher[] +} + +export type DevProxyRoute = { + paths: string | readonly string[] + target: string + cookieRewrite?: CookieRewriteOptions | false +} + +export type DevProxyConfig = { + server?: DevProxyServerConfig + routes: readonly DevProxyRoute[] + cors?: DevProxyCorsConfig +} + +export type DevProxyCliOptions = { + config?: string + envFile?: string + host?: string + port?: string + help?: boolean +} + +export type DevProxyConfigLoadOptions = { + envFile?: string | false +} + +export type ResolvedDevProxyServerOptions = { + host: string + port: number +} + +export type CreateDevProxyAppOptions = Pick & { + fetchImpl?: typeof globalThis.fetch + logger?: Pick +} diff --git a/packages/dev-proxy/tsconfig.json b/packages/dev-proxy/tsconfig.json new file mode 100644 index 0000000000..813a9bd8a3 --- /dev/null +++ b/packages/dev-proxy/tsconfig.json @@ -0,0 +1,17 @@ +{ + "extends": "@dify/tsconfig/node.json", + "compilerOptions": { + "types": [ + "node", + "vitest/globals" + ] + }, + "include": [ + "src/**/*.ts", + "vite.config.ts" + ], + "exclude": [ + "node_modules", + "dist" + ] +} diff --git a/packages/dev-proxy/vite.config.ts b/packages/dev-proxy/vite.config.ts new file mode 100644 index 0000000000..d060ae036e --- /dev/null +++ b/packages/dev-proxy/vite.config.ts @@ -0,0 +1,27 @@ +import { defineConfig } from 'vite-plus' + +export default defineConfig({ + pack: { + clean: true, + deps: { + neverBundle: [ + '@hono/node-server', + 'c12', + 'hono', + ], + }, + entry: [ + 'src/index.ts', + 'src/cli.ts', + ], + format: ['esm'], + outDir: 'dist', + platform: 'node', + sourcemap: true, + target: 'node22', + treeshake: true, + }, + test: { + environment: 'node', + }, +}) diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 2021d87adc..4826ce8163 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -255,6 +255,9 @@ catalogs: ahooks: specifier: 3.9.7 version: 3.9.7 + c12: + specifier: 1.10.0 + version: 1.10.0 class-variance-authority: specifier: 0.7.1 version: 0.7.1 @@ -684,6 +687,37 @@ importers: specifier: 'catalog:' version: 0.1.20(@types/node@25.6.0)(@vitest/coverage-v8@4.1.5(@types/node@25.6.0)(@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.9.0)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.9.0)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3) + packages/dev-proxy: + dependencies: + '@hono/node-server': + specifier: 'catalog:' + version: 2.0.0(hono@4.12.15) + c12: + specifier: 'catalog:' + version: 1.10.0 + hono: + specifier: 'catalog:' + version: 4.12.15 + devDependencies: + '@dify/tsconfig': + specifier: workspace:* + version: link:../tsconfig + '@types/node': + specifier: 'catalog:' + version: 25.6.0 + '@typescript/native-preview': + specifier: 'catalog:' + version: 7.0.0-dev.20260428.1 + vite: + specifier: npm:@voidzero-dev/vite-plus-core@0.1.20 + version: '@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3)' + vite-plus: + specifier: 'catalog:' + version: 0.1.20(@types/node@25.6.0)(@vitest/coverage-v8@4.1.5(@types/node@25.6.0)(@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.9.0)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.9.0)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3) + vitest: + specifier: npm:@voidzero-dev/vite-plus-test@0.1.20 + version: '@voidzero-dev/vite-plus-test@0.1.20(@types/node@25.6.0)(@vitest/coverage-v8@4.1.5(@types/node@25.6.0)(@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.9.0)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(@voidzero-dev/vite-plus-core@0.1.20(@types/node@25.6.0)(esbuild@0.27.2)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.9.0)(jiti@2.6.1)(tsx@4.21.0)(typescript@6.0.3)(yaml@2.8.3)' + packages/dify-ui: dependencies: clsx: @@ -1174,15 +1208,15 @@ importers: '@eslint-react/eslint-plugin': specifier: 'catalog:' version: 3.0.0(eslint@10.2.1(jiti@2.6.1))(typescript@6.0.3) - '@hono/node-server': - specifier: 'catalog:' - version: 2.0.0(hono@4.12.15) '@iconify-json/heroicons': specifier: 'catalog:' version: 1.2.3 '@iconify-json/ri': specifier: 'catalog:' version: 1.2.10 + '@langgenius/dev-proxy': + specifier: workspace:* + version: link:../packages/dev-proxy '@langgenius/dify-ui': specifier: workspace:* version: link:../packages/dify-ui @@ -1336,9 +1370,6 @@ importers: happy-dom: specifier: 'catalog:' version: 20.9.0 - hono: - specifier: 'catalog:' - version: 4.12.15 knip: specifier: 'catalog:' version: 6.7.0(@emnapi/core@1.9.2)(@emnapi/runtime@1.9.2) @@ -4785,6 +4816,10 @@ packages: any-promise@1.3.0: resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} + anymatch@3.1.3: + resolution: {integrity: sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==} + engines: {node: '>= 8'} + are-docs-informative@0.0.2: resolution: {integrity: sha512-ixiS0nLNNG5jNQzgZJNoUpBKdo9yTYZMGJ+QgT2jmjR7G7+QHRCc4v6LQ3NgE7EBJq+o0ams3waJwkrlBom8Ig==} engines: {node: '>=14'} @@ -4847,6 +4882,10 @@ packages: engines: {node: '>=6.0.0'} hasBin: true + binary-extensions@2.3.0: + resolution: {integrity: sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==} + engines: {node: '>=8'} + birecord@0.1.1: resolution: {integrity: sha512-VUpsf/qykW0heRlC8LooCq28Kxn3mAqKohhDG/49rrsQ1dT1CXyj/pgXS+5BSRzFTR/3DyIBOqQOrGyZOh71Aw==} @@ -4897,6 +4936,9 @@ packages: resolution: {integrity: sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==} engines: {node: '>= 0.8'} + c12@1.10.0: + resolution: {integrity: sha512-0SsG7UDhoRWcuSvKWHaXmu5uNjDCDN3nkQLRL4Q42IlFy+ze58FcCoI3uPwINXinkz7ZinbhEgyzYFw9u9ZV8g==} + c12@3.3.4: resolution: {integrity: sha512-cM0ApFQSBXuourJejzwv/AuPRvAxordTyParRVcHjjtXirtkzM0uK2L9TTn9s0cXZbG7E55jCivRQzoxYmRAlA==} peerDependencies: @@ -4971,6 +5013,10 @@ packages: chevrotain@11.1.2: resolution: {integrity: sha512-opLQzEVriiH1uUQ4Kctsd49bRoFDXGGSC4GUqj7pGyxM3RehRhvTlZJc1FL/Flew2p5uwxa1tUDWKzI4wNM8pg==} + chokidar@3.6.0: + resolution: {integrity: sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==} + engines: {node: '>= 8.10.0'} + chokidar@5.0.0: resolution: {integrity: sha512-TQMmc3w+5AxjpL8iIiwebF73dRDF4fBIieAqGn9RGCWaEVwQ6Fb2cGe31Yns0RRIzii5goJ1Y7xbMwo1TxMplw==} engines: {node: '>= 20.19.0'} @@ -4998,6 +5044,9 @@ packages: resolution: {integrity: sha512-77PSwercCZU2Fc4sX94eF8k8Pxte6JAwL4/ICZLFjJLqegs7kCuAsqqj/70NQF6TvDpgFjkubQB2FW2ZZddvQg==} engines: {node: '>=8'} + citty@0.1.6: + resolution: {integrity: sha512-tskPPKEs8D2KPafUypv2gxwJP8h/OaJmC82QQGGDQcHvXX43xF2VDACcJVmZ0EuSxkpO9Kc4MlrA3q0+FG58AQ==} + class-transformer@0.5.1: resolution: {integrity: sha512-SQa1Ws6hUbfC98vKGxZH3KFY0Y1lm5Zm0SY8XX9zbK7FJCyVEac3ATW0RIpwzW+oOfmHE5PMPufDG9hCfoEOMw==} @@ -5097,6 +5146,10 @@ packages: confbox@0.2.4: resolution: {integrity: sha512-ysOGlgTFbN2/Y6Cg3Iye8YKulHw+R2fNXHrgSmXISQdMnomY6eNDprVdW9R5xBguEqI954+S6709UyiO7B+6OQ==} + consola@3.4.2: + resolution: {integrity: sha512-5IKcdX0nnYavi6G7TtOhwkYzyjfJlatbjMjuLSfE2kYT5pMDOilZ4OvMhi637CcDICTmz3wARPoyhqyX1Y+XvA==} + engines: {node: ^14.18.0 || >=16.10.0} + convert-source-map@2.0.0: resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==} @@ -6018,12 +6071,13 @@ packages: resolution: {integrity: sha512-nBF+F1rAZVCu/p7rjzgA+Yb4lfYXrpl7a6VmJrU8wF9I1CKvP/QwPNZHnOlwbTkY6dvtFIzFMSyQXbLoTQPRpA==} engines: {node: '>=8'} - get-tsconfig@4.13.7: - resolution: {integrity: sha512-7tN6rFgBlMgpBML5j8typ92BKFi2sFQvIdpAqLA2beia5avZDrMs0FLZiM5etShWq5irVyGcGMEA1jcDaK7A/Q==} - get-tsconfig@4.14.0: resolution: {integrity: sha512-yTb+8DXzDREzgvYmh6s9vHsSVCHeC0G3PI5bEXNBHtmshPnO+S5O7qgLEOn0I5QvMy6kpZN8K1NKGyilLb93wA==} + giget@1.2.5: + resolution: {integrity: sha512-r1ekGw/Bgpi3HLV3h1MRBIlSAdHoIMklpaQ3OQLFcRw9PwAj2rqigvIbg+dBUI51OxVI2jsEtDywDBjSiuf7Ug==} + hasBin: true + giget@3.2.0: resolution: {integrity: sha512-GvHTWcykIR/fP8cj8dMpuMMkvaeJfPvYnhq0oW+chSeIr+ldX21ifU2Ms6KBoyKZQZmVaUAAhQ2EZ68KJF8a7A==} hasBin: true @@ -6251,6 +6305,10 @@ packages: is-alphanumerical@2.0.1: resolution: {integrity: sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==} + is-binary-path@2.1.0: + resolution: {integrity: sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==} + engines: {node: '>=8'} + is-builtin-module@5.0.0: resolution: {integrity: sha512-f4RqJKBUe5rQkJ2eJEJBXSticB3hGbN9j0yxxMQFqIW89Jp9WYFtzfTcRlstDKVUTRzSOTLKRfO9vIztenwtxA==} engines: {node: '>=18.20'} @@ -6322,6 +6380,10 @@ packages: resolution: {integrity: sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==} engines: {node: '>=8'} + jiti@1.21.7: + resolution: {integrity: sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==} + hasBin: true + jiti@2.6.1: resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==} hasBin: true @@ -6945,6 +7007,9 @@ packages: node-addon-api@7.1.1: resolution: {integrity: sha512-5m3bsyrjFWE1xf7nz7YXdN4udnVtXK6/Yfgn5qnahL6bCkf2yKt4k3nuTKAtT4r3IG8JNR2ncsIMdZuAzJjHQQ==} + node-fetch-native@1.6.7: + resolution: {integrity: sha512-g9yhqoedzIUm0nTnTqAQvueMPVOuIY16bqgAJJC8XOOubYFNwz6IER9qs0Gq2Xd0+CecCKFjtdDTMA4u4xG06Q==} + node-releases@2.0.36: resolution: {integrity: sha512-TdC8FSgHz8Mwtw9g5L4gR/Sh9XhSP/0DEkQxfEFXOpiul5IiHgHan2VhYYb6agDSfp4KuvltmGApc8HMgUrIkA==} @@ -6952,6 +7017,10 @@ packages: resolution: {integrity: sha512-RWk+PI433eESQ7ounYxIp67CYuVsS1uYSonX3kA6ps/3LWfjVQa/ptEg6Y3T6uAMq1mWpX9PQ+qx+QaHpsc7gQ==} engines: {node: ^20.17.0 || >=22.9.0} + normalize-path@3.0.0: + resolution: {integrity: sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==} + engines: {node: '>=0.10.0'} + normalize-wheel@1.0.1: resolution: {integrity: sha512-1OnlAPZ3zgrk8B91HyRj+eVv+kS5u+Z0SCsak6Xil/kmgEia50ga7zfkumayonZrImffAxPU/5WcyGhzetHNPA==} @@ -6979,6 +7048,11 @@ packages: react-router-dom: optional: true + nypm@0.5.4: + resolution: {integrity: sha512-X0SNNrZiGU8/e/zAB7sCTtdxWTMSIO73q+xuKgglm2Yvzwlo8UoC5FNySQFCvl84uPaeADkqHUZUkWy4aH4xOA==} + engines: {node: ^14.16.0 || >=16.10.0} + hasBin: true + object-assign@4.1.1: resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} engines: {node: '>=0.10.0'} @@ -6989,6 +7063,9 @@ packages: obug@2.1.1: resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} + ohash@1.1.6: + resolution: {integrity: sha512-TBu7PtV8YkAZn0tSxobKY2n2aAQva936lhRrj6957aDaCf9IEtqsKbgMzXE/F/sjqYOwmrukeORHNLe5glk7Cg==} + ohash@2.0.11: resolution: {integrity: sha512-RdR9FQrFwNBNXAr4GixM8YaRZRJ5PUWbKYbE5eOsrwAjJW0q2REGcf79oYPsLyskQCZG1PLN+S/K1V00joZAoQ==} @@ -7124,6 +7201,9 @@ packages: resolution: {integrity: sha512-+vnG6S4dYcYxZd+CZxzXCNKdELYZSKfohrk98yajCo1PtRoDgCTrrwOvK1GT0UoAdVszagDVllQc0U1vaX4NUQ==} engines: {node: '>=6'} + pathe@1.1.2: + resolution: {integrity: sha512-whLdWMYL2TwI08hn8/ZqAbrVemu0LNaNNJZX73O6qaIdCTfXutsLhMkjdENX0qhsQ9uIimo4/aQOmXkoon2nDQ==} + pathe@2.0.3: resolution: {integrity: sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==} @@ -7138,6 +7218,9 @@ packages: pend@1.2.0: resolution: {integrity: sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg==} + perfect-debounce@1.0.0: + resolution: {integrity: sha512-xCy9V055GLEqoFaHoC1SoLIaLmWctgCUaBaWxDZ7/Zx4CTyX7cJQLJOok/orfjZAh9kEYpjJa4d0KcJmCbctZA==} + perfect-debounce@2.1.0: resolution: {integrity: sha512-LjgdTytVFXeUgtHZr9WYViYSM/g8MkcTPYDlPa3cDqMirHjKiSZPYd6DoL7pK8AJQr+uWkQvCjHNdiMqsrJs+g==} @@ -7272,6 +7355,9 @@ packages: resolution: {integrity: sha512-h36JMxKRqrAxVD8201FrCpyeNuUY9Y5zZwujr20fFO77tpUtGa6EZzfKw/3WaiBX95fq7+MpsuMLNdSnORAwSA==} engines: {node: '>=14.18.0'} + rc9@2.1.2: + resolution: {integrity: sha512-btXCnMmRIBINM2LDZoEmOogIZU7Qe7zn4BpomSKZ/ykbLObuBdvG+mFq11DL6fjH1DRwHhrlgtYWG96bJiC7Cg==} + rc9@3.0.1: resolution: {integrity: sha512-gMDyleLWVE+i6Sgtc0QbbY6pEKqYs97NGi6isHQPqYlLemPoO8dxQ3uGi0f4NiP98c+jMW6cG1Kx9dDwfvqARQ==} @@ -7448,6 +7534,10 @@ packages: resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} engines: {node: '>= 6'} + readdirp@3.6.0: + resolution: {integrity: sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==} + engines: {node: '>=8.10.0'} + readdirp@5.0.0: resolution: {integrity: sha512-9u/XQ1pvrQtYyMpZe7DXKv2p5CNvyVwzUB6uhLAnQwHMSgKMBR62lc7AHljaeteeHXn11XTAaLLUVZYVZyuRBQ==} engines: {node: '>= 20.19.0'} @@ -7919,6 +8009,9 @@ packages: tinybench@2.9.0: resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} + tinyexec@0.3.2: + resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==} + tinyexec@1.0.4: resolution: {integrity: sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==} engines: {node: '>=18'} @@ -12097,6 +12190,11 @@ snapshots: any-promise@1.3.0: {} + anymatch@3.1.3: + dependencies: + normalize-path: 3.0.0 + picomatch: 2.3.2 + are-docs-informative@0.0.2: {} argparse@2.0.1: {} @@ -12146,6 +12244,8 @@ snapshots: baseline-browser-mapping@2.10.12: {} + binary-extensions@2.3.0: {} + birecord@0.1.1: {} birpc@4.0.0: {} @@ -12195,6 +12295,21 @@ snapshots: bytes@3.1.2: {} + c12@1.10.0: + dependencies: + chokidar: 3.6.0 + confbox: 0.1.8 + defu: 6.1.7 + dotenv: 16.6.1 + giget: 1.2.5 + jiti: 1.21.7 + mlly: 1.8.2 + ohash: 1.1.6 + pathe: 1.1.2 + perfect-debounce: 1.0.0 + pkg-types: 1.3.1 + rc9: 2.1.2 + c12@3.3.4(magicast@0.5.2): dependencies: chokidar: 5.0.0 @@ -12299,6 +12414,18 @@ snapshots: '@chevrotain/utils': 11.1.2 lodash-es: 4.18.0 + chokidar@3.6.0: + dependencies: + anymatch: 3.1.3 + braces: 3.0.3 + glob-parent: 5.1.2 + is-binary-path: 2.1.0 + is-glob: 4.0.3 + normalize-path: 3.0.0 + readdirp: 3.6.0 + optionalDependencies: + fsevents: 2.3.3 + chokidar@5.0.0: dependencies: readdirp: 5.0.0 @@ -12312,6 +12439,10 @@ snapshots: ci-info@4.4.0: {} + citty@0.1.6: + dependencies: + consola: 3.4.2 + class-transformer@0.5.1: {} class-variance-authority@0.7.1: @@ -12407,6 +12538,8 @@ snapshots: confbox@0.2.4: {} + consola@3.4.2: {} + convert-source-map@2.0.0: {} copy-to-clipboard@4.0.2: {} @@ -13513,14 +13646,20 @@ snapshots: dependencies: pump: 3.0.4 - get-tsconfig@4.13.7: - dependencies: - resolve-pkg-maps: 1.0.0 - get-tsconfig@4.14.0: dependencies: resolve-pkg-maps: 1.0.0 + giget@1.2.5: + dependencies: + citty: 0.1.6 + consola: 3.4.2 + defu: 6.1.7 + node-fetch-native: 1.6.7 + nypm: 0.5.4 + pathe: 2.0.3 + tar: 7.5.11 + giget@3.2.0: {} github-from-package@0.0.0: @@ -13820,6 +13959,10 @@ snapshots: is-alphabetical: 2.0.1 is-decimal: 2.0.1 + is-binary-path@2.1.0: + dependencies: + binary-extensions: 2.3.0 + is-builtin-module@5.0.0: dependencies: builtin-modules: 5.0.0 @@ -13874,6 +14017,8 @@ snapshots: html-escaper: 2.0.2 istanbul-lib-report: 3.0.1 + jiti@1.21.7: {} + jiti@2.6.1: {} jotai@2.19.1(@babel/core@7.29.0)(@babel/template@7.28.6)(@types/react@19.2.14)(react@19.2.5): @@ -14766,6 +14911,8 @@ snapshots: node-addon-api@7.1.1: optional: true + node-fetch-native@1.6.7: {} + node-releases@2.0.36: {} normalize-package-data@8.0.0: @@ -14774,6 +14921,8 @@ snapshots: semver: 7.7.4 validate-npm-package-license: 3.0.4 + normalize-path@3.0.0: {} + normalize-wheel@1.0.1: {} nth-check@2.1.1: @@ -14787,12 +14936,23 @@ snapshots: optionalDependencies: next: 16.2.4(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5) + nypm@0.5.4: + dependencies: + citty: 0.1.6 + consola: 3.4.2 + pathe: 2.0.3 + pkg-types: 1.3.1 + tinyexec: 0.3.2 + ufo: 1.6.3 + object-assign@4.1.1: {} object-deep-merge@2.0.0: {} obug@2.1.1: {} + ohash@1.1.6: {} + ohash@2.0.11: {} once@1.4.0: @@ -15027,6 +15187,8 @@ snapshots: path2d@0.2.2: optional: true + pathe@1.1.2: {} + pathe@2.0.3: {} pathval@2.0.1: {} @@ -15038,6 +15200,8 @@ snapshots: pend@1.2.0: {} + perfect-debounce@1.0.0: {} + perfect-debounce@2.1.0: {} picocolors@1.1.1: {} @@ -15177,6 +15341,11 @@ snapshots: radash@12.1.1: {} + rc9@2.1.2: + dependencies: + defu: 6.1.7 + destr: 2.0.5 + rc9@3.0.1: dependencies: defu: 6.1.7 @@ -15380,6 +15549,10 @@ snapshots: util-deprecate: 1.0.2 optional: true + readdirp@3.6.0: + dependencies: + picomatch: 2.3.2 + readdirp@5.0.0: {} recast@0.23.11: @@ -15968,6 +16141,8 @@ snapshots: tinybench@2.9.0: {} + tinyexec@0.3.2: {} + tinyexec@1.0.4: {} tinyglobby@0.2.16: @@ -16053,7 +16228,7 @@ snapshots: tsx@4.21.0: dependencies: esbuild: 0.27.2 - get-tsconfig: 4.13.7 + get-tsconfig: 4.14.0 optionalDependencies: fsevents: 2.3.3 @@ -16609,15 +16784,18 @@ time: '@typescript/native-preview@7.0.0-dev.20260428.1': '2026-04-28T08:09:51.266Z' '@voidzero-dev/vite-plus-core@0.1.20': '2026-04-29T03:08:39.629Z' '@voidzero-dev/vite-plus-test@0.1.20': '2026-04-29T03:08:45.501Z' + c12@1.10.0: '2024-03-06T13:11:04.381Z' concurrently@9.2.1: '2025-08-25T09:50:49.138Z' copy-to-clipboard@4.0.2: '2026-04-24T22:15:18.933Z' eslint-markdown@0.7.0: '2026-04-25T11:31:20.226Z' eslint-plugin-better-tailwindcss@4.5.0: '2026-04-28T06:24:47.281Z' eslint@10.2.1: '2026-04-17T20:17:44.852Z' + hono@4.12.15: '2026-04-24T06:51:10.290Z' i18next@26.0.8: '2026-04-24T19:20:14.685Z' js-yaml@4.1.1: '2025-11-12T15:18:03.524Z' lexical@0.44.0: '2026-04-27T14:47:00.970Z' tldts@7.0.29: '2026-04-28T12:21:32.710Z' + tsx@4.21.0: '2025-11-30T15:56:09.488Z' typescript@6.0.3: '2026-04-16T23:38:27.905Z' uuid@14.0.0: '2026-04-19T15:15:42.302Z' vinext@0.0.45: '2026-04-28T11:43:03.463Z' diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml index b0c007ee4d..ee6ccf00df 100644 --- a/pnpm-workspace.yaml +++ b/pnpm-workspace.yaml @@ -138,6 +138,7 @@ catalog: abcjs: 6.6.3 agentation: 3.0.2 ahooks: 3.9.7 + c12: 1.10.0 class-variance-authority: 0.7.1 client-only: 0.0.1 clsx: 2.1.1 diff --git a/web/.env.example b/web/.env.example index ef3ecd8101..81fff4275d 100644 --- a/web/.env.example +++ b/web/.env.example @@ -17,18 +17,11 @@ NEXT_PUBLIC_COOKIE_DOMAIN= # WebSocket server URL. NEXT_PUBLIC_SOCKET_URL=ws://localhost:5001 -# Dev-only Hono proxy targets. -# The frontend keeps requesting http://localhost:5001 directly, -# the proxy server will forward the request to the target server, -# so that you don't need to run a separate backend server and use online API in development. -# Supported values: dify, enterprise. -# Defaults to dify. Enterprise target listens on port 8082 by default. -HONO_PROXY_TARGET=dify -HONO_PROXY_HOST=127.0.0.1 -HONO_PROXY_PORT= -HONO_CONSOLE_API_PROXY_TARGET= -HONO_PUBLIC_API_PROXY_TARGET= -HONO_ENTERPRISE_API_PROXY_TARGET= +# Dev proxy routes are configured in web/dev-proxy.config.ts. +# pnpm -C web run dev:proxy loads web/.env.local before evaluating that config file. +DEV_PROXY_TARGET=https://cloud.dify.ai +DEV_PROXY_HOST=127.0.0.1 +DEV_PROXY_PORT=5001 # The API PREFIX for MARKETPLACE NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1 diff --git a/web/.storybook/main.ts b/web/.storybook/main.ts index 918860c786..e5bf0ee65e 100644 --- a/web/.storybook/main.ts +++ b/web/.storybook/main.ts @@ -1,7 +1,10 @@ import type { StorybookConfig } from '@storybook/nextjs-vite' const config: StorybookConfig = { - stories: ['../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)'], + stories: [ + '../app/components/**/*.stories.@(js|jsx|mjs|ts|tsx)', + '../features/**/*.stories.@(js|jsx|mjs|ts|tsx)', + ], addons: [ // Not working with Storybook Vite framework // '@storybook/addon-onboarding', diff --git a/web/README.md b/web/README.md index 206541eab6..1748ed6947 100644 --- a/web/README.md +++ b/web/README.md @@ -56,9 +56,9 @@ pnpm -C web run dev # or if you are using vinext which provides a better development experience pnpm -C web run dev:vinext # (optional) start the dev proxy server so that you can use online API in development +# edit web/dev-proxy.config.ts to choose proxy paths +# edit web/.env.local to override DEV_PROXY_TARGET, DEV_PROXY_ENTERPRISE_TARGET, DEV_PROXY_HOST, or DEV_PROXY_PORT pnpm -C web run dev:proxy -# (optional) start the dev proxy for the Enterprise frontend; it listens on 8082 by default -pnpm -C web run dev:proxy -- --target enterprise ``` Open with your browser to see the result. diff --git a/web/__mocks__/provider-context.ts b/web/__mocks__/provider-context.ts index d3296bacd0..10fac8d8b6 100644 --- a/web/__mocks__/provider-context.ts +++ b/web/__mocks__/provider-context.ts @@ -13,6 +13,7 @@ export const baseProviderContextValue: ProviderContextState = { isAPIKeySet: true, plan: defaultPlan, isFetchedPlan: false, + isFetchedPlanInfo: false, enableBilling: false, onPlanInfoChanged: noop, enableReplaceWebAppLogo: false, diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 8a1a6fd131..46d7f7833e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -21,20 +21,14 @@ import { useShallow } from 'zustand/react/shallow' import AppSideBar from '@/app/components/app-sidebar' import { useStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' -import dynamic from '@/next/dynamic' import { usePathname, useRouter } from '@/next/navigation' import { fetchAppDetailDirect } from '@/service/apps' import { AppModeEnum } from '@/types/app' import s from './style.module.css' -const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), { - ssr: false, -}) - type IAppDetailLayoutProps = { children: React.ReactNode appId: string @@ -56,7 +50,6 @@ const AppDetailLayout: FC = (props) => { setAppDetail: state.setAppDetail, setAppSidebarExpand: state.setAppSidebarExpand, }))) - const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [isLoadingAppDetail, setIsLoadingAppDetail] = useState(false) const [appDetailRes, setAppDetailRes] = useState(null) const [navigation, setNavigation] = useState = (props) => {
{children}
- {showTagManagementModal && ( - - )} ) } diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index 44ba5ee8ad..82e47d5c0b 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -1,10 +1,8 @@ 'use client' -import { - useEffect, - useMemo, -} from 'react' +import { useEffect } from 'react' import EducationApplyPage from '@/app/education-apply/education-apply-page' +import RootLoading from '@/app/loading' import { useProviderContext } from '@/context/provider-context' import { useRouter, @@ -13,17 +11,24 @@ import { export default function EducationApply() { const router = useRouter() - const { enableEducationPlan } = useProviderContext() + const { + enableEducationPlan, + isFetchedPlanInfo, + isLoadingEducationAccountInfo, + } = useProviderContext() const searchParams = useSearchParams() const token = searchParams.get('token') - const showEducationApplyPage = useMemo(() => { - return enableEducationPlan && token - }, [enableEducationPlan, token]) useEffect(() => { - if (!showEducationApplyPage) + if (!isFetchedPlanInfo) + return + + if (!enableEducationPlan || !token) router.replace('/') - }, [showEducationApplyPage, router]) + }, [enableEducationPlan, isFetchedPlanInfo, router, token]) + + if (!isFetchedPlanInfo || !enableEducationPlan || !token || isLoadingEducationAccountInfo) + return return } diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index f3bb71c2d2..83bca6a8cb 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -3,8 +3,7 @@ import { Button } from '@langgenius/dify-ui/button' import { Dialog, DialogContent } from '@langgenius/dify-ui/dialog' import { toast } from '@langgenius/dify-ui/toast' import { RiCloseLine } from '@remixicon/react' -import * as React from 'react' -import { useState } from 'react' +import { useCallback, useEffect, useRef, useState } from 'react' import { Trans, useTranslation } from 'react-i18next' import Input from '@/app/components/base/input' import { useRouter } from '@/next/navigation' @@ -18,22 +17,23 @@ import { useLogout } from '@/service/use-common' import { asyncRunSafe } from '@/utils' type Props = { - show: boolean onClose: () => void email: string } -enum STEP { - start = 'start', - verifyOrigin = 'verifyOrigin', - newEmail = 'newEmail', - verifyNew = 'verifyNew', -} +const STEP = { + start: 'start', + verifyOrigin: 'verifyOrigin', + newEmail: 'newEmail', + verifyNew: 'verifyNew', +} as const -const EmailChangeModal = ({ onClose, email, show }: Props) => { +type Step = typeof STEP[keyof typeof STEP] + +const EmailChangeModal = ({ onClose, email }: Props) => { const { t } = useTranslation() const router = useRouter() - const [step, setStep] = useState(STEP.start) + const [step, setStep] = useState(STEP.start) const [code, setCode] = useState('') const [mail, setMail] = useState('') const [time, setTime] = useState(0) @@ -41,13 +41,25 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { const [newEmailExited, setNewEmailExited] = useState(false) const [unAvailableEmail, setUnAvailableEmail] = useState(false) const [isCheckingEmail, setIsCheckingEmail] = useState(false) + const timerRef = useRef | null>(null) + + const clearCountdown = useCallback(() => { + if (!timerRef.current) + return + + clearInterval(timerRef.current) + timerRef.current = null + }, []) + + useEffect(() => clearCountdown, [clearCountdown]) const startCount = () => { + clearCountdown() setTime(60) - const timer = setInterval(() => { + timerRef.current = setInterval(() => { setTime((prev) => { - if (prev <= 0) { - clearInterval(timer) + if (prev <= 1) { + clearCountdown() return 0 } return prev - 1 @@ -181,7 +193,7 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { } return ( - !open && onClose()}> + !open && onClose()}>
diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index 2a4ae86f84..0de33a2a71 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -332,11 +332,15 @@ export default function AccountPage() { /> ) } - setShowUpdateEmail(false)} - email={userProfile.email} - /> + {/* Use conditional JSX instead of a mounted controlled Dialog so closing destroys the email-change form session. */} + {showUpdateEmail + ? ( + setShowUpdateEmail(false)} + email={userProfile.email} + /> + ) + : null} ) } diff --git a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx index 5d3c008989..1be3799480 100644 --- a/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx +++ b/web/app/components/app/overview/apikey-info-panel/apikey-info-panel.test-utils.tsx @@ -31,6 +31,7 @@ const defaultProviderContext = { isAPIKeySet: false, plan: defaultPlan, isFetchedPlan: false, + isFetchedPlanInfo: false, enableBilling: false, onPlanInfoChanged: noop, enableReplaceWebAppLogo: false, diff --git a/web/app/components/apps/__tests__/app-card.spec.tsx b/web/app/components/apps/__tests__/app-card.spec.tsx index 6a71dbac52..4edf5604da 100644 --- a/web/app/components/apps/__tests__/app-card.spec.tsx +++ b/web/app/components/apps/__tests__/app-card.spec.tsx @@ -301,9 +301,9 @@ vi.mock('@/app/components/base/tooltip', () => ({ default: ({ children, popupContent }: { children: React.ReactNode, popupContent: React.ReactNode }) => React.createElement('div', { title: popupContent }, children), })) -// TagSelector has API dependency (service/tag) - mock for isolated testing -vi.mock('@/app/components/base/tag-management/selector', () => ({ - default: ({ tags }: { tags?: { id: string, name: string }[] }) => { +// AppCardTags has tag API dependencies - mock for isolated testing +vi.mock('@/features/tag-management/components/app-card-tags', () => ({ + AppCardTags: ({ tags }: { tags?: { id: string, name: string }[] }) => { return React.createElement('div', { 'aria-label': 'tag-selector' }, tags?.map((tag: { id: string, name: string }) => React.createElement('span', { key: tag.id }, tag.name))) }, })) @@ -400,13 +400,30 @@ describe('AppCard', () => { it('should handle app with tags', () => { const appWithTags = { ...mockApp, - tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app' as const, binding_count: 0 }], } render() // Verify the tag selector component renders expect(screen.getByLabelText('tag-selector')).toBeInTheDocument() }) + it('should display refreshed tag names from app props when tag ids stay the same', () => { + const firstApp = createMockApp({ + tags: [{ id: 'tag1', name: 'Old Tag', type: 'app' as const, binding_count: 0 }], + }) + const refreshedApp = createMockApp({ + tags: [{ id: 'tag1', name: 'New Tag', type: 'app' as const, binding_count: 0 }], + }) + + const { rerender } = render() + expect(screen.getByText('Old Tag')).toBeInTheDocument() + + rerender() + + expect(screen.getByText('New Tag')).toBeInTheDocument() + expect(screen.queryByText('Old Tag')).not.toBeInTheDocument() + }) + it('should render with onRefresh callback', () => { render() expect(screen.getByTitle('Test App')).toBeInTheDocument() @@ -1167,9 +1184,9 @@ describe('AppCard', () => { const multiTagApp = { ...mockApp, tags: [ - { id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }, - { id: 'tag2', name: 'Tag 2', type: 'app', binding_count: 0 }, - { id: 'tag3', name: 'Tag 3', type: 'app', binding_count: 0 }, + { id: 'tag1', name: 'Tag 1', type: 'app' as const, binding_count: 0 }, + { id: 'tag2', name: 'Tag 2', type: 'app' as const, binding_count: 0 }, + { id: 'tag3', name: 'Tag 3', type: 'app' as const, binding_count: 0 }, ], } render() @@ -1324,7 +1341,7 @@ describe('AppCard', () => { it('should stop propagation when clicking tag selector area', () => { const multiTagApp = createMockApp({ - tags: [{ id: 'tag1', name: 'Tag 1', type: 'app', binding_count: 0 }], + tags: [{ id: 'tag1', name: 'Tag 1', type: 'app' as const, binding_count: 0 }], }) render() diff --git a/web/app/components/apps/__tests__/list.spec.tsx b/web/app/components/apps/__tests__/list.spec.tsx index 9d1b39ef06..41d2ccbc80 100644 --- a/web/app/components/apps/__tests__/list.spec.tsx +++ b/web/app/components/apps/__tests__/list.spec.tsx @@ -1,7 +1,6 @@ import { act, fireEvent, screen } from '@testing-library/react' import * as React from 'react' import { createSystemFeaturesWrapper } from '@/__tests__/utils/mock-system-features' -import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' @@ -29,6 +28,11 @@ vi.mock('@/service/client', () => ({ infiniteOptions: (options: unknown) => mockAppListInfiniteOptions(options), }, }, + tags: { + list: { + queryOptions: (options: unknown) => options, + }, + }, systemFeatures: { queryKey: () => ['console', 'systemFeatures'], }, @@ -139,10 +143,6 @@ vi.mock('@/service/use-apps', () => ({ }), })) -vi.mock('@/service/tag', () => ({ - fetchTagList: vi.fn().mockResolvedValue([{ id: 'tag-1', name: 'Test Tag', type: 'app' }]), -})) - vi.mock('@/config', async (importOriginal) => { const actual = await importOriginal() return { @@ -236,10 +236,6 @@ type AppListInfiniteOptions = { describe('List', () => { beforeEach(() => { vi.clearAllMocks() - useTagStore.setState({ - tagList: [{ id: 'tag-1', name: 'Test Tag', type: 'app', binding_count: 0 }], - showTagManagementModal: false, - }) mockIsCurrentWorkspaceEditor.mockReturnValue(true) mockIsCurrentWorkspaceDatasetOperator.mockReturnValue(false) mockDragging = false diff --git a/web/app/components/apps/app-card.tsx b/web/app/components/apps/app-card.tsx index 458c7578c7..06c5c8a9d8 100644 --- a/web/app/components/apps/app-card.tsx +++ b/web/app/components/apps/app-card.tsx @@ -1,7 +1,6 @@ 'use client' import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' -import type { Tag } from '@/app/components/base/tag-management/constant' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import type { EnvironmentVariable } from '@/app/components/workflow/types' import type { WorkflowOnlineUser } from '@/models/app' @@ -36,11 +35,11 @@ import { Trans, useTranslation } from 'react-i18next' import { AppTypeIcon } from '@/app/components/app/type-selector' import AppIcon from '@/app/components/base/app-icon' import Input from '@/app/components/base/input' -import TagSelector from '@/app/components/base/tag-management/selector' import { UserAvatarList } from '@/app/components/base/user-avatar-list' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' +import { AppCardTags } from '@/features/tag-management/components/app-card-tags' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { AccessMode } from '@/models/access-control' import dynamic from '@/next/dynamic' @@ -77,6 +76,7 @@ type AppCardProps = { app: App onlineUsers?: WorkflowOnlineUser[] onRefresh?: () => void + onOpenTagManagement?: () => void } type AppCardOperationsMenuProps = { @@ -207,7 +207,7 @@ const AppCardOperationsMenuContent: React.FC ) } -const AppCard = ({ app, onlineUsers = [], onRefresh }: AppCardProps) => { +const AppCard = ({ app, onlineUsers = [], onRefresh, onOpenTagManagement = () => {} }: AppCardProps) => { const { t } = useTranslation() const deleteAppNameInputId = useId() const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) @@ -396,19 +396,6 @@ const AppCard = ({ app, onlineUsers = [], onRefresh }: AppCardProps) => { const shouldShowAccessControlOption = systemFeatures.webapp_auth.enabled && isCurrentWorkspaceEditor const operationsMenuWidthClassName = shouldShowSwitchOption ? 'w-[256px]' : 'w-[216px]' - const appTagsKey = useMemo(() => app.tags.map(tag => tag.id).join(','), [app.tags]) - const [tagState, setTagState] = useState<{ key: string, tags: Tag[] }>(() => ({ - key: appTagsKey, - tags: app.tags, - })) - const tags = tagState.key === appTagsKey ? tagState.tags : app.tags - const handleTagsUpdate = useCallback((nextTags: Tag[]) => { - setTagState({ - key: appTagsKey, - tags: nextTags, - }) - }, [appTagsKey]) - const EditTimeText = useMemo(() => { const timeText = formatTime({ date: (app.updated_at || app.created_at) * 1000, @@ -523,15 +510,12 @@ const AppCard = ({ app, onlineUsers = [], onRefresh }: AppCardProps) => { e.preventDefault() }} > -
- tag.id)} - selectedTags={tags} - onCacheUpdate={handleTagsUpdate} - onChange={onRefresh} +
+
diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 728ef38ba5..0fd31dfb79 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -11,10 +11,9 @@ import { useTranslation } from 'react-i18next' import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' -import TagFilter from '@/app/components/base/tag-management/filter' -import { useStore as useTagStore } from '@/app/components/base/tag-management/store' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' +import { TagFilter } from '@/features/tag-management/components/tag-filter' import { CheckModal } from '@/hooks/use-pay' import dynamic from '@/next/dynamic' import { consoleQuery } from '@/service/client' @@ -24,12 +23,12 @@ import AppCard from './app-card' import { AppCardSkeleton } from './app-card-skeleton' import Empty from './empty' import Footer from './footer' -import useAppsQueryState from './hooks/use-apps-query-state' +import useAppsQueryStateHook from './hooks/use-apps-query-state' import { useDSLDragDrop } from './hooks/use-dsl-drag-drop' import { useWorkflowOnlineUsers } from './hooks/use-workflow-online-users' import NewAppCard from './new-app-card' -const TagManagementModal = dynamic(() => import('@/app/components/base/tag-management'), { +const TagManagementModal = dynamic(() => import('@/features/tag-management/components/tag-management-modal').then(mod => mod.TagManagementModal), { ssr: false, }) const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-from-dsl-modal'), { @@ -57,18 +56,20 @@ const List: FC = ({ const { t } = useTranslation() const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions()) const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, isLoadingCurrentWorkspace } = useAppContext() - const showTagManagementModal = useTagStore(s => s.showTagManagementModal) const [activeTab, setActiveTab] = useQueryState( 'category', parseAsAppListCategory, ) - const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = useAppsQueryState() + // eslint-disable-next-line react/use-state -- custom URL query hook, not React.useState + const appsQuery = useAppsQueryStateHook() + const { query: { tagIDs = [], keywords = '', isCreatedByMe: queryIsCreatedByMe = false }, setQuery } = appsQuery const [isCreatedByMe, setIsCreatedByMe] = useState(queryIsCreatedByMe) const [tagFilterValue, setTagFilterValue] = useState(tagIDs) const [searchKeywords, setSearchKeywords] = useState(keywords) const newAppCardRef = useRef(null) const containerRef = useRef(null) + const [showTagManagementModal, setShowTagManagementModal] = useState(false) const [showCreateFromDSLModal, setShowCreateFromDSLModal] = useState(false) const [droppedDSLFile, setDroppedDSLFile] = useState() const setKeywords = useCallback((keywords: string) => { @@ -245,7 +246,7 @@ const List: FC = ({ {t('showMyCreatedAppsOnly', { ns: 'app' })}
- + setShowTagManagementModal(true)} /> = ({ app={app} onlineUsers={workflowOnlineUsersMap[app.id] ?? []} onRefresh={refetch} + onOpenTagManagement={() => setShowTagManagementModal(true)} /> )) : } @@ -302,9 +304,12 @@ const List: FC = ({ )}
- {showTagManagementModal && ( - - )} + setShowTagManagementModal(false)} + onTagsChange={refetch} + /> {showCreateFromDSLModal && ( diff --git a/web/app/components/base/tag-management/__tests__/tag-remove-modal.spec.tsx b/web/app/components/base/tag-management/__tests__/tag-remove-modal.spec.tsx deleted file mode 100644 index 943b7bc8ff..0000000000 --- a/web/app/components/base/tag-management/__tests__/tag-remove-modal.spec.tsx +++ /dev/null @@ -1,123 +0,0 @@ -import type { Tag } from '../constant' -import { render, screen } from '@testing-library/react' -import userEvent from '@testing-library/user-event' -import TagRemoveModal from '../tag-remove-modal' - -const mockTag: Tag = { - id: 'tag-1', - name: 'Frontend', - type: 'app', - binding_count: 3, -} - -describe('TagRemoveModal', () => { - beforeEach(() => { - vi.clearAllMocks() - }) - - // Rendering behavior and visibility control. - describe('Rendering', () => { - it('should render modal content when show is true', () => { - render( - , - ) - - expect(screen.getByText('common.tag.delete')).toBeInTheDocument() - expect(screen.getByText('"Frontend"')).toBeInTheDocument() - expect(screen.getByText('common.tag.deleteTip')).toBeInTheDocument() - expect(screen.getByText('common.operation.cancel')).toBeInTheDocument() - expect(screen.getByText('common.operation.delete')).toBeInTheDocument() - }) - - it('should not render modal content when show is false', () => { - render( - , - ) - - expect(screen.queryByText('common.tag.delete')).not.toBeInTheDocument() - expect(screen.queryByText('common.tag.deleteTip')).not.toBeInTheDocument() - }) - }) - - // User interactions for closing and confirming actions. - describe('User Interactions', () => { - it('should call onClose when top-right close icon is clicked', async () => { - const user = userEvent.setup() - const onClose = vi.fn() - render( - , - ) - - const closeIconButton = screen.getByTestId('tag-remove-modal-close-button') - expect(closeIconButton).toBeInTheDocument() - await user.click(closeIconButton) - - expect(onClose).toHaveBeenCalledTimes(1) - }) - - it('should call onClose when cancel button is clicked', async () => { - const user = userEvent.setup() - const onClose = vi.fn() - - render( - , - ) - - await user.click(screen.getByText('common.operation.cancel')) - expect(onClose).toHaveBeenCalledTimes(1) - }) - - it('should call onConfirm when delete button is clicked', async () => { - const user = userEvent.setup() - const onConfirm = vi.fn() - - render( - , - ) - - await user.click(screen.getByText('common.operation.delete')) - expect(onConfirm).toHaveBeenCalledTimes(1) - }) - }) - - // Edge case for unusual tag names in the title. - describe('Edge Cases', () => { - it('should render quoted empty tag name safely', () => { - render( - , - ) - - expect(screen.getByText('""')).toBeInTheDocument() - }) - }) -}) diff --git a/web/app/components/base/tag-management/constant.ts b/web/app/components/base/tag-management/constant.ts deleted file mode 100644 index 3c60041383..0000000000 --- a/web/app/components/base/tag-management/constant.ts +++ /dev/null @@ -1,6 +0,0 @@ -export type Tag = { - id: string - name: string - type: string - binding_count: number -} diff --git a/web/app/components/base/tag-management/index.tsx b/web/app/components/base/tag-management/index.tsx deleted file mode 100644 index 19fbfcb7c9..0000000000 --- a/web/app/components/base/tag-management/index.tsx +++ /dev/null @@ -1,62 +0,0 @@ -'use client' -import { toast } from '@langgenius/dify-ui/toast' -import { useEffect, useState } from 'react' -import { useTranslation } from 'react-i18next' -import Modal from '@/app/components/base/modal' -import { createTag, fetchTagList } from '@/service/tag' -import { useStore as useTagStore } from './store' -import TagItemEditor from './tag-item-editor' - -type TagManagementModalProps = { - type: 'knowledge' | 'app' - show: boolean -} -const TagManagementModal = ({ show, type }: TagManagementModalProps) => { - const { t } = useTranslation() - const tagList = useTagStore(s => s.tagList) - const setTagList = useTagStore(s => s.setTagList) - const setShowTagManagementModal = useTagStore(s => s.setShowTagManagementModal) - const getTagList = async (type: 'knowledge' | 'app') => { - const res = await fetchTagList(type) - setTagList(res) - } - const [pending, setPending] = useState(false) - const [name, setName] = useState('') - const createNewTag = async () => { - if (!name) - return - if (pending) - return - try { - setPending(true) - const newTag = await createTag(name, type) - toast.success(t('tag.created', { ns: 'common' })) - setTagList([ - newTag, - ...tagList, - ]) - setName('') - setPending(false) - } - catch { - toast.error(t('tag.failed', { ns: 'common' })) - setPending(false) - } - } - useEffect(() => { - getTagList(type) - }, [type]) - return ( - setShowTagManagementModal(false)}> -
{t('tag.manageTags', { ns: 'common' })}
-
setShowTagManagementModal(false)}> - -
-
- setName(e.target.value)} onKeyDown={e => e.key === 'Enter' && !e.nativeEvent.isComposing && createNewTag()} onBlur={createNewTag} /> - {tagList.map(tag => ())} -
-
- ) -} -export default TagManagementModal diff --git a/web/app/components/base/tag-management/selector.tsx b/web/app/components/base/tag-management/selector.tsx deleted file mode 100644 index 0eb233ba4b..0000000000 --- a/web/app/components/base/tag-management/selector.tsx +++ /dev/null @@ -1,116 +0,0 @@ -import type { FC } from 'react' -import type { Tag } from '@/app/components/base/tag-management/constant' -import { cn } from '@langgenius/dify-ui/cn' -import { - Popover, - PopoverContent, - PopoverTrigger, -} from '@langgenius/dify-ui/popover' -import { useCallback, useMemo, useState } from 'react' -import { useTranslation } from 'react-i18next' -import { fetchTagList } from '@/service/tag' -import Panel from './panel' -import { useStore as useTagStore } from './store' -import Trigger from './trigger' - -export type TagSelectorProps = { - targetID: string - isPopover?: boolean - position?: 'bl' | 'br' - type: 'knowledge' | 'app' - value: string[] - selectedTags: Tag[] - onCacheUpdate: (tags: Tag[]) => void - onChange?: () => void - minWidth?: number | string -} - -const TagSelector: FC = ({ - targetID, - isPopover = true, - position, - type, - value, - selectedTags, - onCacheUpdate, - onChange, - minWidth, -}) => { - const { t } = useTranslation() - const tagList = useTagStore(s => s.tagList) - const setTagList = useTagStore(s => s.setTagList) - const [open, setOpen] = useState(false) - - const getTagList = useCallback(async () => { - const res = await fetchTagList(type) - setTagList(res) - }, [setTagList, type]) - - const tags = useMemo(() => { - if (selectedTags?.length) - return selectedTags.filter(selectedTag => tagList.find(tag => tag.id === selectedTag.id)).map(tag => tag.name) - return [] - }, [selectedTags, tagList]) - - const placement = useMemo(() => { - if (position === 'bl') - return 'bottom-start' as const - if (position === 'br') - return 'bottom-end' as const - return 'bottom' as const - }, [position]) - - const resolvedMinWidth = useMemo(() => { - if (minWidth == null) - return undefined - - return typeof minWidth === 'number' ? `${minWidth}px` : minWidth - }, [minWidth]) - - const triggerLabel = useMemo(() => { - if (tags.length) - return tags.join(', ') - - return t('tag.addTag', { ns: 'common' }) - }, [tags, t]) - - if (!isPopover) - return null - - return ( - - - - - - - - - ) -} - -export default TagSelector diff --git a/web/app/components/base/tag-management/store.ts b/web/app/components/base/tag-management/store.ts deleted file mode 100644 index 197d31ed7a..0000000000 --- a/web/app/components/base/tag-management/store.ts +++ /dev/null @@ -1,19 +0,0 @@ -import type { Tag } from './constant' -import { create } from 'zustand' - -type State = { - tagList: Tag[] - showTagManagementModal: boolean -} - -type Action = { - setTagList: (tagList?: Tag[]) => void - setShowTagManagementModal: (showTagManagementModal: boolean) => void -} - -export const useStore = create(set => ({ - tagList: [], - setTagList: tagList => set(() => ({ tagList })), - showTagManagementModal: false, - setShowTagManagementModal: showTagManagementModal => set(() => ({ showTagManagementModal })), -})) diff --git a/web/app/components/base/tag-management/tag-remove-modal.tsx b/web/app/components/base/tag-management/tag-remove-modal.tsx deleted file mode 100644 index 1088ca2043..0000000000 --- a/web/app/components/base/tag-management/tag-remove-modal.tsx +++ /dev/null @@ -1,48 +0,0 @@ -'use client' - -import type { Tag } from '@/app/components/base/tag-management/constant' -import { Button } from '@langgenius/dify-ui/button' -import { cn } from '@langgenius/dify-ui/cn' -import { noop } from 'es-toolkit/function' -import { useTranslation } from 'react-i18next' -import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import Modal from '@/app/components/base/modal' - -type TagRemoveModalProps = { - show: boolean - tag: Tag - onConfirm: () => void - onClose: () => void -} - -const TagRemoveModal = ({ show, tag, onConfirm, onClose }: TagRemoveModalProps) => { - const { t } = useTranslation() - - return ( - -
- -
-
- -
-
- {`${t('tag.delete', { ns: 'common' })} `} - {`"${tag.name}"`} -
-
- {t('tag.deleteTip', { ns: 'common' })} -
-
- - -
-
- ) -} - -export default TagRemoveModal diff --git a/web/app/components/billing/hooks/use-education-discount.ts b/web/app/components/billing/hooks/use-education-discount.ts new file mode 100644 index 0000000000..dedad4707e --- /dev/null +++ b/web/app/components/billing/hooks/use-education-discount.ts @@ -0,0 +1,37 @@ +'use client' +import { toast } from '@langgenius/dify-ui/toast' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useAppContext } from '@/context/app-context' +import { fetchSubscriptionUrls } from '@/service/billing' +import { Plan } from '../type' + +export const useEducationDiscount = () => { + const { t } = useTranslation() + const { isCurrentWorkspaceManager } = useAppContext() + const [isEducationDiscountLoading, setIsEducationDiscountLoading] = useState(false) + + const handleEducationDiscount = useCallback(async () => { + if (isEducationDiscountLoading) + return + + if (!isCurrentWorkspaceManager) { + toast.error(t('buyPermissionDeniedTip', { ns: 'billing' })) + return + } + + setIsEducationDiscountLoading(true) + try { + const res = await fetchSubscriptionUrls(Plan.professional, 'year') + window.location.href = res.url + } + finally { + setIsEducationDiscountLoading(false) + } + }, [isCurrentWorkspaceManager, isEducationDiscountLoading, t]) + + return { + handleEducationDiscount, + isEducationDiscountLoading, + } +} diff --git a/web/app/components/billing/plan/__tests__/index.spec.tsx b/web/app/components/billing/plan/__tests__/index.spec.tsx index 27f6b3005d..e9e0fd7012 100644 --- a/web/app/components/billing/plan/__tests__/index.spec.tsx +++ b/web/app/components/billing/plan/__tests__/index.spec.tsx @@ -1,11 +1,15 @@ import { act, fireEvent, render, screen, waitFor } from '@testing-library/react' import { EDUCATION_VERIFYING_LOCALSTORAGE_ITEM } from '@/app/education-apply/constants' +import { fetchSubscriptionUrls } from '@/service/billing' import { Plan, SelfHostedPlan } from '../../type' import PlanComp from '../index' let currentPath = '/billing' const push = vi.fn() +let isCurrentWorkspaceManager = true +let assignedHref = '' +const originalLocation = window.location vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push }), @@ -27,10 +31,16 @@ vi.mock('@/context/provider-context', () => ({ vi.mock('@/context/app-context', () => ({ useAppContext: () => ({ userProfile: { email: 'user@example.com' }, - isCurrentWorkspaceManager: true, + isCurrentWorkspaceManager, }), })) +vi.mock('@/service/billing', () => ({ + fetchSubscriptionUrls: vi.fn(), +})) + +const fetchSubscriptionUrlsMock = vi.mocked(fetchSubscriptionUrls) + const mutateAsyncMock = vi.fn() let isPending = false vi.mock('@/service/use-education', () => ({ @@ -78,10 +88,26 @@ describe('PlanComp', () => { }, } + beforeAll(() => { + Object.defineProperty(window, 'location', { + configurable: true, + value: { + get href() { + return assignedHref + }, + set href(value: string) { + assignedHref = value + }, + } as unknown as Location, + }) + }) + beforeEach(() => { vi.clearAllMocks() currentPath = '/billing' isPending = false + isCurrentWorkspaceManager = true + assignedHref = '' providerContextMock.mockReturnValue({ plan: planMock, enableEducationPlan: true, @@ -90,6 +116,14 @@ describe('PlanComp', () => { }) mutateAsyncMock.mockReset() mutateAsyncMock.mockResolvedValue({ token: 'token' }) + fetchSubscriptionUrlsMock.mockResolvedValue({ url: 'https://subscription.example' }) + }) + + afterAll(() => { + Object.defineProperty(window, 'location', { + configurable: true, + value: originalLocation, + }) }) it('renders plan info and handles education verify success', async () => { @@ -170,6 +204,49 @@ describe('PlanComp', () => { expect(screen.getByText('education.toVerified'))!.toBeInTheDocument() }) + it('shows education discount button and keeps upgrade button for education accounts', async () => { + providerContextMock.mockReturnValue({ + plan: { ...planMock, type: Plan.sandbox }, + enableEducationPlan: true, + allowRefreshEducationVerify: false, + isEducationAccount: true, + }) + render() + + fireEvent.click(screen.getByText('education.useEducationDiscount')) + + await waitFor(() => { + expect(fetchSubscriptionUrlsMock).toHaveBeenCalledWith(Plan.professional, 'year') + expect(assignedHref).toBe('https://subscription.example') + }) + expect(screen.getByTestId('plan-upgrade-btn'))!.toBeInTheDocument() + }) + + it('does not show education discount button for non-sandbox education accounts', () => { + providerContextMock.mockReturnValue({ + plan: planMock, + enableEducationPlan: true, + allowRefreshEducationVerify: false, + isEducationAccount: true, + }) + render() + + expect(screen.queryByText('education.useEducationDiscount')).not.toBeInTheDocument() + }) + + it('does not show education discount button for non-manager sandbox education accounts', () => { + isCurrentWorkspaceManager = false + providerContextMock.mockReturnValue({ + plan: { ...planMock, type: Plan.sandbox }, + enableEducationPlan: true, + allowRefreshEducationVerify: false, + isEducationAccount: true, + }) + render() + + expect(screen.queryByText('education.useEducationDiscount')).not.toBeInTheDocument() + }) + it('renders enterprise plan without upgrade button', () => { providerContextMock.mockReturnValue({ plan: { ...planMock, type: SelfHostedPlan.enterprise }, diff --git a/web/app/components/billing/plan/index.tsx b/web/app/components/billing/plan/index.tsx index 49d4ffa779..498736475c 100644 --- a/web/app/components/billing/plan/index.tsx +++ b/web/app/components/billing/plan/index.tsx @@ -23,6 +23,7 @@ import { useEducationVerify } from '@/service/use-education' import { getDaysUntilEndOfMonth } from '@/utils/time' import { Loading } from '../../base/icons/src/public/thought' import { NUM_INFINITE } from '../config' +import { useEducationDiscount } from '../hooks/use-education-discount' import { Plan, SelfHostedPlan } from '../type' import UpgradeBtn from '../upgrade-btn' import AppsInfo from '../usage-info/apps-info' @@ -39,12 +40,13 @@ const PlanComp: FC = ({ const { t } = useTranslation() const router = useRouter() const path = usePathname() - const { userProfile } = useAppContext() + const { userProfile, isCurrentWorkspaceManager } = useAppContext() const { plan, enableEducationPlan, allowRefreshEducationVerify, isEducationAccount } = useProviderContext() const isAboutToExpire = allowRefreshEducationVerify const { type, } = plan + const isEnterprisePlan = String(type) === SelfHostedPlan.enterprise const { usage, @@ -65,6 +67,7 @@ const PlanComp: FC = ({ })() const [showModal, setShowModal] = React.useState(false) + const { handleEducationDiscount, isEducationDiscountLoading } = useEducationDiscount() const { mutateAsync, isPending } = useEducationVerify() const setShowAccountSettingModal = useModalContextSelector(s => s.setShowAccountSettingModal) const unmountedRef = useUnmountedRef() @@ -97,7 +100,7 @@ const PlanComp: FC = ({ {plan.type === Plan.team && ( )} - {(plan.type as any) === SelfHostedPlan.enterprise && ( + {isEnterprisePlan && ( )}
@@ -115,7 +118,14 @@ const PlanComp: FC = ({ {isPending && } )} - {(plan.type as any) !== SelfHostedPlan.enterprise && ( + {enableEducationPlan && isEducationAccount && type === Plan.sandbox && isCurrentWorkspaceManager && ( + + )} + {!isEnterprisePlan && ( { usage: buildUsage(), total: buildUsage(), }, + enableEducationPlan: false, + isEducationAccount: false, }) ;(useGetPricingPageLanguage as Mock).mockImplementation(() => mockLanguage) }) @@ -72,6 +74,39 @@ describe('Pricing', () => { expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByTestId('pricing-link')).toHaveAttribute('href', 'https://dify.ai/en/pricing#plans-and-features') }) + + it('should default to yearly billing for education accounts', () => { + ;(useProviderContext as Mock).mockReturnValue({ + plan: { + type: Plan.sandbox, + usage: buildUsage(), + total: buildUsage(), + }, + enableEducationPlan: true, + isEducationAccount: true, + }) + + render() + + expect(screen.getByRole('switch')).toBeChecked() + }) + + it('should not default to yearly billing for non-manager education accounts', () => { + ;(useAppContext as Mock).mockReturnValue({ isCurrentWorkspaceManager: false }) + ;(useProviderContext as Mock).mockReturnValue({ + plan: { + type: Plan.sandbox, + usage: buildUsage(), + total: buildUsage(), + }, + enableEducationPlan: true, + isEducationAccount: true, + }) + + render() + + expect(screen.getByRole('switch')).not.toBeChecked() + }) }) describe('Props', () => { diff --git a/web/app/components/billing/pricing/index.tsx b/web/app/components/billing/pricing/index.tsx index cd88be5fb3..6d9b0f67cf 100644 --- a/web/app/components/billing/pricing/index.tsx +++ b/web/app/components/billing/pricing/index.tsx @@ -39,9 +39,11 @@ const pricingScrollAreaClassNames = { const Pricing: FC = ({ onCancel, }) => { - const { plan } = useProviderContext() + const { plan, enableEducationPlan, isEducationAccount } = useProviderContext() const { isCurrentWorkspaceManager } = useAppContext() - const [planRange, setPlanRange] = React.useState(PlanRange.monthly) + const shouldDefaultToYearly = isCurrentWorkspaceManager && enableEducationPlan && isEducationAccount + const [selectedPlanRange, setSelectedPlanRange] = React.useState() + const planRange = selectedPlanRange ?? (shouldDefaultToYearly ? PlanRange.yearly : PlanRange.monthly) const [currentCategory, setCurrentCategory] = useState(CategoryEnum.CLOUD) const canPay = isCurrentWorkspaceManager @@ -73,7 +75,7 @@ const Pricing: FC = ({ currentCategory={currentCategory} onChangeCategory={setCurrentCategory} currentPlanRange={planRange} - onChangePlanRange={setPlanRange} + onChangePlanRange={setSelectedPlanRange} /> ({ useAppContext: vi.fn(), })) +vi.mock('@/context/provider-context', () => ({ + useProviderContext: vi.fn(), +})) + vi.mock('@/service/billing', () => ({ fetchSubscriptionUrls: vi.fn(), })) @@ -38,6 +43,7 @@ vi.mock('../../../assets', () => ({ })) const mockUseAppContext = useAppContext as Mock +const mockUseProviderContext = useProviderContext as Mock const mockUseAsyncWindowOpen = useAsyncWindowOpen as Mock const mockBillingInvoices = consoleClient.billing.invoices as Mock const mockFetchSubscriptionUrls = fetchSubscriptionUrls as Mock @@ -72,6 +78,10 @@ beforeEach(() => { vi.clearAllMocks() toast.dismiss() mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: true }) + mockUseProviderContext.mockReturnValue({ + enableEducationPlan: false, + isEducationAccount: false, + }) mockUseAsyncWindowOpen.mockReturnValue(vi.fn(async open => await open())) mockBillingInvoices.mockResolvedValue({ url: 'https://billing.example' }) mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://subscription.example' }) @@ -260,6 +270,127 @@ describe('CloudPlanItem', () => { }) }) + it('should use education discount checkout for yearly professional plan when education account is active', async () => { + mockUseProviderContext.mockReturnValue({ + enableEducationPlan: true, + isEducationAccount: true, + }) + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'education.useEducationDiscount' })) + + await waitFor(() => { + expect(mockFetchSubscriptionUrls).toHaveBeenCalledWith(Plan.professional, 'year') + expect(assignedHref).toBe('https://subscription.example') + }) + }) + + it('should show default CTA and hide warning when current user is not workspace manager', () => { + mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false }) + mockUseProviderContext.mockReturnValue({ + enableEducationPlan: true, + isEducationAccount: true, + }) + + render( + , + ) + + expect(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }))!.toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'education.useEducationDiscount' })).not.toBeInTheDocument() + expect(screen.queryByText('education.planNotSupportEducationDiscount')).not.toBeInTheDocument() + }) + + it('should hide education unsupported warning when current user is not workspace manager', () => { + mockUseAppContext.mockReturnValue({ isCurrentWorkspaceManager: false }) + mockUseProviderContext.mockReturnValue({ + enableEducationPlan: true, + isEducationAccount: true, + }) + + render( + , + ) + + expect(screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }))!.toBeInTheDocument() + expect(screen.queryByText('education.planNotSupportEducationDiscount')).not.toBeInTheDocument() + }) + + it('should show education unsupported warning below the button without changing button text or blocking checkout', async () => { + mockUseProviderContext.mockReturnValue({ + enableEducationPlan: true, + isEducationAccount: true, + }) + + render( + , + ) + + const button = screen.getByRole('button', { name: 'billing.plansCommon.startBuilding' }) + expect(button)!.not.toBeDisabled() + expect(screen.getByText('education.planNotSupportEducationDiscount'))!.toBeInTheDocument() + + fireEvent.click(button) + expect(screen.getByText('education.educationPricingConfirm.title'))!.toBeInTheDocument() + expect(screen.getByText(/^education\.educationPricingConfirm\.description/))!.toBeInTheDocument() + expect(screen.queryByRole('button', { name: 'common.operation.close' }))!.not.toBeInTheDocument() + expect(screen.getByRole('button', { name: 'education.educationPricingConfirm.cancel' }))!.toBeInTheDocument() + fireEvent.click(screen.getByRole('button', { name: 'education.educationPricingConfirm.continue' })) + + await waitFor(() => { + expect(mockFetchSubscriptionUrls).toHaveBeenCalledWith(Plan.professional, 'month') + expect(assignedHref).toBe('https://subscription.example') + }) + }) + + it('should close the unsupported plan confirm without checkout when canceled', async () => { + mockUseProviderContext.mockReturnValue({ + enableEducationPlan: true, + isEducationAccount: true, + }) + + render( + , + ) + + fireEvent.click(screen.getByRole('button', { name: 'billing.plansCommon.getStarted' })) + fireEvent.click(screen.getByRole('button', { name: 'education.educationPricingConfirm.cancel' })) + + await waitFor(() => { + expect(screen.queryByText('education.educationPricingConfirm.title'))!.not.toBeInTheDocument() + }) + expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled() + expect(assignedHref).toBe('') + }) + // Covers L62-63: loading guard prevents double click it('should ignore second click while loading', async () => { // Make the first fetch hang until we resolve it diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/button.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/button.tsx index 8115646748..5e3f1cab0d 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/button.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/button.tsx @@ -1,6 +1,5 @@ import type { BasicPlan } from '../../../type' import { cn } from '@langgenius/dify-ui/cn' -import { RiArrowRightLine } from '@remixicon/react' import * as React from 'react' import { Plan } from '../../../type' @@ -24,6 +23,7 @@ type ButtonProps = { isPlanDisabled: boolean btnText: string handleGetPayUrl: () => void + warningText?: string } const Button = ({ @@ -31,22 +31,30 @@ const Button = ({ isPlanDisabled, btnText, handleGetPayUrl, + warningText, }: ButtonProps) => { return ( - + {warningText && ( +
+ {warningText} +
)} - onClick={handleGetPayUrl} - > - {btnText} - {!isPlanDisabled && } - +
) } diff --git a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx index 53d5025f08..d3dc47b29f 100644 --- a/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx +++ b/web/app/components/billing/pricing/plans/cloud-plan-item/index.tsx @@ -1,15 +1,26 @@ 'use client' import type { FC } from 'react' import type { BasicPlan } from '../../../type' +import { + AlertDialog, + AlertDialogActions, + AlertDialogCancelButton, + AlertDialogConfirmButton, + AlertDialogContent, + AlertDialogDescription, + AlertDialogTitle, +} from '@langgenius/dify-ui/alert-dialog' import { toast } from '@langgenius/dify-ui/toast' import * as React from 'react' import { useMemo } from 'react' import { useTranslation } from 'react-i18next' import { useAppContext } from '@/context/app-context' +import { useProviderContext } from '@/context/provider-context' import { useAsyncWindowOpen } from '@/hooks/use-async-window-open' import { fetchSubscriptionUrls } from '@/service/billing' import { consoleClient } from '@/service/client' import { ALL_PLANS } from '../../../config' +import { useEducationDiscount } from '../../../hooks/use-education-discount' import { Plan } from '../../../type' import { Professional, Sandbox, Team } from '../../assets' import { PlanRange } from '../../plan-switcher/plan-range-switcher' @@ -22,6 +33,10 @@ const ICON_MAP = { [Plan.team]: , } +type ConfirmType = { + type: 'info' | 'warning' +} + type CloudPlanItemProps = { currentPlan: BasicPlan plan: BasicPlan @@ -33,6 +48,7 @@ const CloudPlanItem: FC = ({ plan, currentPlan, planRange, + canPay, }) => { const { t } = useTranslation() const [loading, setLoading] = React.useState(false) @@ -45,9 +61,23 @@ const CloudPlanItem: FC = ({ const isCurrentPaidPlan = isCurrent && !isFreePlan const isPlanDisabled = isCurrentPaidPlan ? false : planInfo.level <= ALL_PLANS[currentPlan].level const { isCurrentWorkspaceManager } = useAppContext() + const { enableEducationPlan, isEducationAccount } = useProviderContext() + const isEducationDiscountMode = enableEducationPlan && isEducationAccount + const isEducationDiscountSupportedPlan = plan === Plan.professional && isYear + const selectedPlanName = t(`${i18nPrefix}.name`, { ns: 'billing' }) + const selectedBillingPeriod = t(`educationPricingConfirm.billingPeriod.${isYear ? 'yearly' : 'monthly'}`, { ns: 'education' }) + const educationDiscountWarningText = canPay && isEducationDiscountMode && !isFreePlan && !isEducationDiscountSupportedPlan + ? t('planNotSupportEducationDiscount', { ns: 'education' }) + : undefined const openAsyncWindow = useAsyncWindowOpen() + const { handleEducationDiscount, isEducationDiscountLoading } = useEducationDiscount() + const [showEducationPricingConfirm, setShowEducationPricingConfirm] = React.useState(false) + const educationPricingConfirmInfo: ConfirmType = { type: 'warning' } const btnText = useMemo(() => { + if (canPay && isEducationDiscountMode && isEducationDiscountSupportedPlan && !isCurrent) + return t('useEducationDiscount', { ns: 'education' }) + if (isCurrent) return t('plansCommon.currentPlan', { ns: 'billing' }) @@ -56,15 +86,20 @@ const CloudPlanItem: FC = ({ [Plan.professional]: t('plansCommon.startBuilding', { ns: 'billing' }), [Plan.team]: t('plansCommon.getStarted', { ns: 'billing' }), })[plan] - }, [isCurrent, plan, t]) + }, [canPay, isCurrent, isEducationDiscountMode, isEducationDiscountSupportedPlan, plan, t]) - const handleGetPayUrl = async () => { - if (loading) + const handlePayCurrentPlan = async () => { + if (loading || isEducationDiscountLoading) return if (isPlanDisabled) return + if (isEducationDiscountMode && isEducationDiscountSupportedPlan && !isCurrentPaidPlan) { + await handleEducationDiscount() + return + } + if (!isCurrentWorkspaceManager) { toast.error(t('buyPermissionDeniedTip', { ns: 'billing' })) return @@ -96,6 +131,18 @@ const CloudPlanItem: FC = ({ setLoading(false) } } + const handleGetPayUrl = async () => { + if (educationDiscountWarningText && !isPlanDisabled) { + setShowEducationPricingConfirm(true) + return + } + + await handlePayCurrentPlan() + } + const handleContinueCurrentPlan = async () => { + setShowEducationPricingConfirm(false) + await handlePayCurrentPlan() + } return (
@@ -146,9 +193,46 @@ const CloudPlanItem: FC = ({ isPlanDisabled={isPlanDisabled} btnText={btnText} handleGetPayUrl={handleGetPayUrl} + warningText={educationDiscountWarningText} />
+ + {showEducationPricingConfirm &&
} + +
+ + {t('educationPricingConfirm.title', { ns: 'education' })} + + + {t('educationPricingConfirm.description', { + ns: 'education', + planName: selectedPlanName, + billingPeriod: selectedBillingPeriod, + })} + +
+ + setShowEducationPricingConfirm(false)} + disabled={loading} + > + {t('educationPricingConfirm.cancel', { ns: 'education' })} + + + {t('educationPricingConfirm.continue', { ns: 'education' })} + + +
+
) } diff --git a/web/app/components/datasets/list/__tests__/datasets.spec.tsx b/web/app/components/datasets/list/__tests__/datasets.spec.tsx index 5b777e0b2e..f78622a9cd 100644 --- a/web/app/components/datasets/list/__tests__/datasets.spec.tsx +++ b/web/app/components/datasets/list/__tests__/datasets.spec.tsx @@ -56,8 +56,6 @@ vi.mock('@/context/app-context', () => ({ // Mock useDatasetCardState hook vi.mock('../dataset-card/hooks/use-dataset-card-state', () => ({ useDatasetCardState: () => ({ - tags: [], - setTags: vi.fn(), modalState: { showRenameModal: false, showConfirmDelete: false, @@ -77,6 +75,14 @@ vi.mock('../../rename-modal', () => ({ default: () => null, })) +vi.mock('../dataset-card', () => ({ + default: ({ dataset }: { dataset: DataSet }) => ( +
+ {dataset.name} +
+ ), +})) + function createMockDataset(overrides: Partial = {}): DataSet { return { id: 'dataset-1', diff --git a/web/app/components/datasets/list/__tests__/index.spec.tsx b/web/app/components/datasets/list/__tests__/index.spec.tsx index adc53debbd..7e46c46f1a 100644 --- a/web/app/components/datasets/list/__tests__/index.spec.tsx +++ b/web/app/components/datasets/list/__tests__/index.spec.tsx @@ -36,11 +36,6 @@ vi.mock('@/context/external-api-panel-context', () => ({ }), })) -// Mock tag management store -vi.mock('@/app/components/base/tag-management/store', () => ({ - useStore: () => false, -})) - // Mock useDocumentTitle hook vi.mock('@/hooks/use-document-title', () => ({ default: vi.fn(), @@ -108,15 +103,16 @@ vi.mock('@/app/components/develop/secret-key/secret-key-modal', () => ({ })) // Mock TagManagementModal -vi.mock('@/app/components/base/tag-management', () => ({ - default: () =>
, +vi.mock('@/features/tag-management/components/tag-management-modal', () => ({ + TagManagementModal: ({ show }: { show: boolean }) => show ?
: null, })) // Mock TagFilter -vi.mock('@/app/components/base/tag-management/filter', () => ({ - default: ({ onChange }: { value: string[], onChange: (val: string[]) => void }) => ( +vi.mock('@/features/tag-management/components/tag-filter', () => ({ + TagFilter: ({ onChange, onOpenTagManagement }: { value: string[], onChange: (val: string[]) => void, onOpenTagManagement: () => void }) => (
+
), })) @@ -226,7 +222,7 @@ describe('List', () => { it('should have correct container styling', () => { const { container } = render() const mainContainer = container.firstChild as HTMLElement - expect(mainContainer).toHaveClass('scroll-container', 'relative', 'flex', 'grow', 'flex-col') + expect(mainContainer).toHaveClass('relative', 'flex', 'grow', 'flex-col') }) }) @@ -312,15 +308,9 @@ describe('List', () => { expect(mockSetShowExternalApiPanel).toHaveBeenCalledWith(false) }) - it('should show TagManagementModal when showTagManagementModal is true', async () => { - vi.doMock('@/app/components/base/tag-management/store', () => ({ - useStore: () => true, // showTagManagementModal is true - })) - - vi.resetModules() - const { default: ListComponent } = await import('../index') - - render() + it('should show TagManagementModal when tag management is opened', () => { + render() + fireEvent.click(screen.getByText('Manage Tags')) expect(screen.getByTestId('tag-management-modal')).toBeInTheDocument() }) diff --git a/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx b/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx index f6c7e1e93d..55176faf47 100644 --- a/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx +++ b/web/app/components/datasets/list/dataset-card/__tests__/index.spec.tsx @@ -30,8 +30,6 @@ vi.mock('@/context/app-context', () => ({ vi.mock('../hooks/use-dataset-card-state', () => ({ useDatasetCardState: () => ({ - tags: [], - setTags: vi.fn(), modalState: { showRenameModal: false, showConfirmDelete: false, @@ -55,8 +53,8 @@ vi.mock('../components/dataset-card-header', () => ({ vi.mock('../components/dataset-card-modals', () => ({ default: () =>
, })) -vi.mock('../components/tag-area', () => ({ - default: ({ onClick }: { onClick: (e: React.MouseEvent) => void, ref?: React.Ref }) => ( +vi.mock('@/features/tag-management/components/dataset-card-tags', () => ({ + DatasetCardTags: ({ onClick }: { onClick: (e: React.MouseEvent) => void }) => (
), })) diff --git a/web/app/components/datasets/list/dataset-card/components/__tests__/tag-area.spec.tsx b/web/app/components/datasets/list/dataset-card/components/__tests__/tag-area.spec.tsx deleted file mode 100644 index 2858469cdb..0000000000 --- a/web/app/components/datasets/list/dataset-card/components/__tests__/tag-area.spec.tsx +++ /dev/null @@ -1,198 +0,0 @@ -import type { Tag } from '@/app/components/base/tag-management/constant' -import type { DataSet } from '@/models/datasets' -import { fireEvent, render, screen } from '@testing-library/react' -import { useRef } from 'react' -import { describe, expect, it, vi } from 'vitest' -import { IndexingType } from '@/app/components/datasets/create/step-two' -import { ChunkingMode, DatasetPermission, DataSourceType } from '@/models/datasets' -import TagArea from '../tag-area' - -// Mock TagSelector as it's a complex component from base -vi.mock('@/app/components/base/tag-management/selector', () => ({ - default: ({ value, selectedTags, onCacheUpdate, onChange }: { - value: string[] - selectedTags: Tag[] - onCacheUpdate: (tags: Tag[]) => void - onChange?: () => void - }) => ( -
-
{value.join(',')}
-
- {selectedTags.length} - {' '} - tags -
- - -
- ), -})) - -describe('TagArea', () => { - const createMockDataset = (overrides: Partial = {}): DataSet => ({ - id: 'dataset-1', - name: 'Test Dataset', - description: 'Test description', - provider: 'vendor', - permission: DatasetPermission.allTeamMembers, - data_source_type: DataSourceType.FILE, - indexing_technique: IndexingType.QUALIFIED, - embedding_available: true, - app_count: 5, - document_count: 10, - word_count: 1000, - updated_at: 1609545600, - tags: [], - embedding_model: 'text-embedding-ada-002', - embedding_model_provider: 'openai', - created_by: 'user-1', - doc_form: ChunkingMode.text, - ...overrides, - } as DataSet) - - const mockTags: Tag[] = [ - { id: 'tag-1', name: 'Tag 1', type: 'knowledge', binding_count: 0 }, - { id: 'tag-2', name: 'Tag 2', type: 'knowledge', binding_count: 0 }, - ] - - const defaultProps = { - dataset: createMockDataset(), - tags: mockTags, - setTags: vi.fn(), - onSuccess: vi.fn(), - isHoveringTagSelector: false, - onClick: vi.fn(), - } - - beforeEach(() => { - vi.clearAllMocks() - }) - - describe('Rendering', () => { - it('should render without crashing', () => { - render() - expect(screen.getByTestId('tag-selector')).toBeInTheDocument() - }) - - it('should render TagSelector with correct value', () => { - render() - expect(screen.getByTestId('tag-values')).toHaveTextContent('tag-1,tag-2') - }) - - it('should display selected tags count', () => { - render() - expect(screen.getByTestId('selected-count')).toHaveTextContent('2 tags') - }) - }) - - describe('Props', () => { - it('should pass dataset id to TagSelector', () => { - const dataset = createMockDataset({ id: 'custom-dataset-id' }) - render() - expect(screen.getByTestId('tag-selector')).toBeInTheDocument() - }) - - it('should render with empty tags', () => { - render() - expect(screen.getByTestId('selected-count')).toHaveTextContent('0 tags') - }) - - it('should forward ref correctly', () => { - const TestComponent = () => { - const ref = useRef(null) - return - } - render() - expect(screen.getByTestId('tag-selector')).toBeInTheDocument() - }) - }) - - describe('User Interactions', () => { - it('should call onClick when container is clicked', () => { - const onClick = vi.fn() - const { container } = render() - - const wrapper = container.firstChild as HTMLElement - fireEvent.click(wrapper) - - expect(onClick).toHaveBeenCalledTimes(1) - }) - - it('should call setTags when tags are updated', () => { - const setTags = vi.fn() - render() - - fireEvent.click(screen.getByText('Update Tags')) - - expect(setTags).toHaveBeenCalledWith([{ id: 'new-tag', name: 'New Tag', type: 'knowledge', binding_count: 0 }]) - }) - - it('should call onSuccess when onChange is triggered', () => { - const onSuccess = vi.fn() - render() - - fireEvent.click(screen.getByText('Trigger Change')) - - expect(onSuccess).toHaveBeenCalledTimes(1) - }) - }) - - describe('Styles', () => { - it('should have opacity class when embedding is not available', () => { - const dataset = createMockDataset({ embedding_available: false }) - const { container } = render() - const wrapper = container.firstChild as HTMLElement - expect(wrapper).toHaveClass('opacity-30') - }) - - it('should not have opacity class when embedding is available', () => { - const dataset = createMockDataset({ embedding_available: true }) - const { container } = render() - const wrapper = container.firstChild as HTMLElement - expect(wrapper).not.toHaveClass('opacity-30') - }) - - it('should show mask when not hovering and has tags', () => { - const { container } = render() - const maskDiv = container.querySelector('.bg-tag-selector-mask-bg') - expect(maskDiv).toBeInTheDocument() - expect(maskDiv).not.toHaveClass('hidden') - }) - - it('should hide mask when hovering', () => { - const { container } = render() - // When hovering, the mask div should have 'hidden' class - const maskDiv = container.querySelector('.absolute.right-0.top-0') - expect(maskDiv).toHaveClass('hidden') - }) - - it('should make TagSelector visible when tags exist', () => { - const { container } = render() - const tagSelectorWrapper = container.querySelector('.visible') - expect(tagSelectorWrapper).toBeInTheDocument() - }) - }) - - describe('Edge Cases', () => { - it('should handle undefined onSuccess', () => { - render() - // Should not throw when clicking Trigger Change - expect(() => fireEvent.click(screen.getByText('Trigger Change'))).not.toThrow() - }) - - it('should handle many tags', () => { - const manyTags: Tag[] = Array.from({ length: 20 }, (_, i) => ({ - id: `tag-${i}`, - name: `Tag ${i}`, - type: 'knowledge' as const, - binding_count: 0, - })) - render() - expect(screen.getByTestId('selected-count')).toHaveTextContent('20 tags') - }) - }) -}) diff --git a/web/app/components/datasets/list/dataset-card/components/tag-area.tsx b/web/app/components/datasets/list/dataset-card/components/tag-area.tsx deleted file mode 100644 index 2c8d6aa73a..0000000000 --- a/web/app/components/datasets/list/dataset-card/components/tag-area.tsx +++ /dev/null @@ -1,55 +0,0 @@ -import type { Tag } from '@/app/components/base/tag-management/constant' -import type { DataSet } from '@/models/datasets' -import { cn } from '@langgenius/dify-ui/cn' -import * as React from 'react' -import TagSelector from '@/app/components/base/tag-management/selector' - -type TagAreaProps = { - dataset: DataSet - tags: Tag[] - setTags: (tags: Tag[]) => void - onSuccess?: () => void - isHoveringTagSelector: boolean - onClick: (e: React.MouseEvent) => void -} - -const TagArea = React.forwardRef(({ - dataset, - tags, - setTags, - onSuccess, - isHoveringTagSelector, - onClick, -}, ref) => ( -
-
0 && 'visible', - )} - > - tag.id)} - selectedTags={tags} - onCacheUpdate={setTags} - onChange={onSuccess} - /> -
-
-
-)) -TagArea.displayName = 'TagArea' - -export default TagArea diff --git a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts index 7d07bcf9d0..fa4868f391 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/__tests__/use-dataset-card-state.spec.ts @@ -66,15 +66,6 @@ describe('useDatasetCardState', () => { }) describe('Initial State', () => { - it('should return tags from dataset', () => { - const dataset = createMockDataset() - const { result } = renderHook(() => - useDatasetCardState({ dataset, onSuccess: vi.fn() }), - ) - - expect(result.current.tags).toEqual(dataset.tags) - }) - it('should have initial modal state closed', () => { const dataset = createMockDataset() const { result } = renderHook(() => @@ -96,36 +87,6 @@ describe('useDatasetCardState', () => { }) }) - describe('Tags State', () => { - it('should update tags when setTags is called', () => { - const dataset = createMockDataset() - const { result } = renderHook(() => - useDatasetCardState({ dataset, onSuccess: vi.fn() }), - ) - - act(() => { - result.current.setTags([{ id: 'tag-2', name: 'Tag 2', type: 'knowledge', binding_count: 0 }]) - }) - - expect(result.current.tags).toEqual([{ id: 'tag-2', name: 'Tag 2', type: 'knowledge', binding_count: 0 }]) - }) - - it('should sync tags when dataset tags change', () => { - const dataset = createMockDataset() - const { result, rerender } = renderHook( - ({ dataset }) => useDatasetCardState({ dataset, onSuccess: vi.fn() }), - { initialProps: { dataset } }, - ) - - const newTags = [{ id: 'tag-3', name: 'Tag 3', type: 'knowledge', binding_count: 0 }] - const updatedDataset = createMockDataset({ tags: newTags }) - - rerender({ dataset: updatedDataset }) - - expect(result.current.tags).toEqual(newTags) - }) - }) - describe('Modal Handlers', () => { it('should open rename modal when openRenameModal is called', () => { const dataset = createMockDataset() @@ -279,15 +240,6 @@ describe('useDatasetCardState', () => { }) describe('Edge Cases', () => { - it('should handle empty tags array', () => { - const dataset = createMockDataset({ tags: [] }) - const { result } = renderHook(() => - useDatasetCardState({ dataset, onSuccess: vi.fn() }), - ) - - expect(result.current.tags).toEqual([]) - }) - it('should handle undefined onSuccess', async () => { const dataset = createMockDataset() const { result } = renderHook(() => diff --git a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts index 6cffbb6828..88aa7b50ae 100644 --- a/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts +++ b/web/app/components/datasets/list/dataset-card/hooks/use-dataset-card-state.ts @@ -1,7 +1,6 @@ -import type { Tag } from '@/app/components/base/tag-management/constant' import type { DataSet } from '@/models/datasets' import { toast } from '@langgenius/dify-ui/toast' -import { useCallback, useEffect, useState } from 'react' +import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useCheckDatasetUsage, useDeleteDataset } from '@/service/use-dataset-card' import { useExportPipelineDSL } from '@/service/use-pipeline' @@ -20,11 +19,6 @@ type UseDatasetCardStateOptions = { export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateOptions) => { const { t } = useTranslation() - const [tags, setTags] = useState(dataset.tags) - - useEffect(() => { - setTags(dataset.tags) - }, [dataset.tags]) // Modal state const [modalState, setModalState] = useState({ @@ -113,10 +107,6 @@ export const useDatasetCardState = ({ dataset, onSuccess }: UseDatasetCardStateO }, [dataset.id, deleteDatasetMutation, onSuccess, t, closeConfirmDelete]) return { - // Tag state - tags, - setTags, - // Modal state modalState, openRenameModal, diff --git a/web/app/components/datasets/list/dataset-card/index.tsx b/web/app/components/datasets/list/dataset-card/index.tsx index 5bd032d151..3fe4b6f7c0 100644 --- a/web/app/components/datasets/list/dataset-card/index.tsx +++ b/web/app/components/datasets/list/dataset-card/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { DataSet } from '@/models/datasets' -import { useHover } from 'ahooks' -import { useMemo, useRef } from 'react' +import { useMemo } from 'react' import { useSelector as useAppContextWithSelector } from '@/context/app-context' +import { DatasetCardTags } from '@/features/tag-management/components/dataset-card-tags' import { useRouter } from '@/next/navigation' import CornerLabels from './components/corner-labels' import DatasetCardFooter from './components/dataset-card-footer' @@ -10,29 +10,27 @@ import DatasetCardHeader from './components/dataset-card-header' import DatasetCardModals from './components/dataset-card-modals' import Description from './components/description' import OperationsDropdown from './components/operations-dropdown' -import TagArea from './components/tag-area' -import { useDatasetCardState } from './hooks/use-dataset-card-state' +import { useDatasetCardState as useDatasetCardController } from './hooks/use-dataset-card-state' const EXTERNAL_PROVIDER = 'external' type DatasetCardProps = { dataset: DataSet onSuccess?: () => void + onOpenTagManagement?: () => void } const DatasetCard = ({ dataset, onSuccess, + onOpenTagManagement = () => {}, }: DatasetCardProps) => { const { push } = useRouter() const isCurrentWorkspaceDatasetOperator = useAppContextWithSelector(state => state.isCurrentWorkspaceDatasetOperator) - const tagSelectorRef = useRef(null) - const isHoveringTagSelector = useHover(tagSelectorRef) + const datasetCard = useDatasetCardController({ dataset, onSuccess }) const { - tags, - setTags, modalState, openRenameModal, closeRenameModal, @@ -40,7 +38,7 @@ const DatasetCard = ({ handleExportPipeline, detectIsUsedByApp, onConfirmDelete, - } = useDatasetCardState({ dataset, onSuccess }) + } = datasetCard const isExternalProvider = dataset.provider === EXTERNAL_PROVIDER const isPipelineUnpublished = useMemo(() => { @@ -72,14 +70,13 @@ const DatasetCard = ({ - void } const Datasets = ({ tags, keywords, includeAll, + onOpenTagManagement = () => {}, }: Props) => { const { t } = useTranslation() const isCurrentWorkspaceEditor = useAppContextWithSelector(state => state.isCurrentWorkspaceEditor) @@ -60,7 +62,7 @@ const Datasets = ({