From 59e99ee1ae0dd9e5eb822f165090a55de7a334f7 Mon Sep 17 00:00:00 2001 From: chariri Date: Tue, 26 May 2026 17:20:10 +0900 Subject: [PATCH] refactor(api): migrate console tags to tenant/user via DI and improve tests (#36658) Co-authored-by: Asuka Minato Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- api/controllers/console/files.py | 2 +- api/controllers/console/remote_files.py | 2 +- api/controllers/console/tag/tags.py | 46 ++++--- .../console/datasets/test_external.py | 116 ++++++++++++++++++ .../controllers/console/test_feature.py | 65 ++++++++++ .../controllers/console/test_files.py | 101 +++++++++++++++ .../console/datasets/test_external.py | 10 +- .../controllers/console/tag/test_tags.py | 82 ++++--------- .../controllers/console/test_files.py | 49 ++++---- .../controllers/console/test_remote_files.py | 19 ++- .../controllers/console/test_wraps.py | 63 +++++++++- 11 files changed, 441 insertions(+), 114 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/test_feature.py create mode 100644 api/tests/test_containers_integration_tests/controllers/console/test_files.py diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 3ef006c051..5197120c13 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -28,7 +28,7 @@ from controllers.console.wraps import ( from extensions.ext_database import db from fields.file_fields import FileResponse, UploadConfig from libs.login import login_required -from models.account import Account +from models import Account from services.file_service import FileService from . import console_ns diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 93435d1151..9f7fe6379c 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -18,7 +18,7 @@ from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from graphon.file import helpers as file_helpers from libs.login import login_required -from models.account import Account +from models import Account from services.file_service import FileService diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index a37e56e2b8..4e2ea2060d 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_tenant_id, + with_current_user, +) from fields.base import ResponseModel -from libs.login import current_account_with_tenant, login_required +from libs.login import login_required +from models import Account from models.enums import TagType from services.tag_service import ( SaveTagPayload, @@ -92,8 +99,8 @@ class TagListApi(Resource): params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."} ) @console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])}) - def get(self): - _, current_tenant_id = current_account_with_tenant() + @with_current_tenant_id + def get(self, current_tenant_id: str): raw_args = request.args.to_dict() param = TagListQueryParam.model_validate(raw_args) tags = TagService.get_tags(param.type, current_tenant_id, param.keyword) @@ -109,9 +116,9 @@ class TagListApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - current_user, _ = current_account_with_tenant() - # The role of the current user in the ta table must be admin, owner, or editor + @with_current_user + def post(self, current_user: Account): + # Allow users with edit permission, or dataset editors (including dataset operators). if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() @@ -132,8 +139,8 @@ class TagUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, tag_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def patch(self, current_user: Account, tag_id: UUID): tag_id_str = str(tag_id) # The role of the current user in the ta table must be admin, owner, or editor if not (current_user.has_edit_permission or current_user.is_dataset_editor): @@ -163,20 +170,19 @@ class TagUpdateDeleteApi(Resource): return "", 204 -def _require_tag_binding_edit_permission() -> None: +def _require_tag_binding_edit_permission(current_user: Account) -> None: """ Ensure the current account can edit tag bindings. Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant. """ - current_user, _ = current_account_with_tenant() # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not (current_user.has_edit_permission or current_user.is_dataset_editor): raise Forbidden() -def _create_tag_bindings() -> tuple[dict[str, str], int]: - _require_tag_binding_edit_permission() +def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission(current_user) payload = TagBindingPayload.model_validate(console_ns.payload or {}) TagService.save_tag_binding( @@ -189,8 +195,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]: return {"result": "success"}, 200 -def _remove_tag_bindings() -> tuple[dict[str, str], int]: - _require_tag_binding_edit_permission() +def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]: + _require_tag_binding_edit_permission(current_user) payload = TagBindingRemovePayload.model_validate(console_ns.payload or {}) TagService.delete_tag_binding( @@ -213,8 +219,9 @@ class TagBindingCollectionApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - return _create_tag_bindings() + @with_current_user + def post(self, current_user: Account): + return _create_tag_bindings(current_user) @console_ns.route("/tag-bindings/remove") @@ -228,5 +235,6 @@ class TagBindingRemoveApi(Resource): @setup_required @login_required @account_initialization_required - def post(self): - return _remove_tag_bindings() + @with_current_user + def post(self, current_user: Account): + return _remove_tag_bindings(current_user) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py new file mode 100644 index 0000000000..d6b7e9e636 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_external.py @@ -0,0 +1,116 @@ +"""Integration tests for console external knowledge API endpoints.""" + +from __future__ import annotations + +import json + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from models.dataset import ExternalKnowledgeApis +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def _create_external_api( + db_session: Session, + *, + tenant_id: str, + account_id: str, + name: str, +) -> ExternalKnowledgeApis: + external_api = ExternalKnowledgeApis( + tenant_id=tenant_id, + created_by=account_id, + updated_by=account_id, + name=name, + description=f"{name} description", + settings=json.dumps( + { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + ), + ) + db_session.add(external_api) + db_session.commit() + return external_api + + +def test_external_api_template_list_filters_paginates_and_scopes_to_authenticated_tenant( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise the real list route, including query parsing, DB lookup, and tenant isolation.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers) + account_id = account.id + tenant_id = tenant.id + foreign_account_id = foreign_account.id + foreign_tenant_id = foreign_tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + + _create_external_api( + db_session_with_containers, + tenant_id=tenant_id, + account_id=account_id, + name="Alpha Primary", + ) + _create_external_api( + db_session_with_containers, + tenant_id=tenant_id, + account_id=account_id, + name="Alpha Secondary", + ) + _create_external_api( + db_session_with_containers, + tenant_id=tenant_id, + account_id=account_id, + name="Beta Unmatched", + ) + _create_external_api( + db_session_with_containers, + tenant_id=foreign_tenant_id, + account_id=foreign_account_id, + name="Alpha Foreign", + ) + + response = test_client_with_containers.get( + "/console/api/datasets/external-knowledge-api?page=1&limit=1&keyword=Alpha", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json is not None + assert response.json["page"] == 1 + assert response.json["limit"] == 1 + assert response.json["total"] == 2 + assert response.json["has_more"] is True + assert len(response.json["data"]) == 1 + + first_page_item = response.json["data"][0] + assert first_page_item["tenant_id"] == tenant_id + assert first_page_item["name"] in {"Alpha Primary", "Alpha Secondary"} + assert first_page_item["settings"] == { + "endpoint": "https://example.com", + "api_key": "test-api-key", + } + assert first_page_item["dataset_bindings"] == [] + + second_response = test_client_with_containers.get( + "/console/api/datasets/external-knowledge-api?page=2&limit=1&keyword=Alpha", + headers=headers, + ) + + assert second_response.status_code == 200 + assert second_response.json is not None + assert second_response.json["page"] == 2 + assert second_response.json["limit"] == 1 + assert second_response.json["total"] == 2 + assert len(second_response.json["data"]) == 1 + + second_page_item = second_response.json["data"][0] + assert second_page_item["name"] in {"Alpha Primary", "Alpha Secondary"} + assert second_response.json["data"][0]["tenant_id"] == tenant_id diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_feature.py b/api/tests/test_containers_integration_tests/controllers/console/test_feature.py new file mode 100644 index 0000000000..9eb76c8152 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/test_feature.py @@ -0,0 +1,65 @@ +"""Integration tests for console feature endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +from flask.testing import FlaskClient +from sqlalchemy.orm import Session + +from services.feature_service import FeatureModel, FeatureService, LimitationModel +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_feature_list_returns_current_tenant_configuration_without_vector_space( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise auth, tenant injection, and the feature response shaping contract.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + feature_model = FeatureModel( + members=LimitationModel(size=1, limit=2), + apps=LimitationModel(size=3, limit=4), + vector_space=LimitationModel(size=5, limit=6), + ) + + with patch.object(FeatureService, "get_features", return_value=feature_model) as get_features: + response = test_client_with_containers.get( + "/console/api/features", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json is not None + assert response.json["members"] == {"size": 1, "limit": 2} + assert response.json["apps"] == {"size": 3, "limit": 4} + assert "vector_space" not in response.json + get_features.assert_called_once_with(tenant_id, exclude_vector_space=True) + + +def test_feature_vector_space_returns_current_tenant_usage( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise tenant injection and vector-space response serialization through the registered route.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + tenant_id = tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + + vector_space = SimpleNamespace(model_dump=lambda: {"size": 0, "limit": 100}) + + with patch.object(FeatureService, "get_vector_space", return_value=vector_space) as get_vector_space: + response = test_client_with_containers.get( + "/console/api/features/vector-space", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json == {"size": 0, "limit": 100} + get_vector_space.assert_called_once_with(tenant_id) diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_files.py b/api/tests/test_containers_integration_tests/controllers/console/test_files.py new file mode 100644 index 0000000000..8985c1ba66 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/console/test_files.py @@ -0,0 +1,101 @@ +"""Integration tests for console file endpoints.""" + +from __future__ import annotations + +from io import BytesIO + +from flask.testing import FlaskClient +from sqlalchemy import select +from sqlalchemy.orm import Session + +from configs import dify_config +from models.model import UploadFile +from tests.test_containers_integration_tests.controllers.console.helpers import ( + authenticate_console_client, + create_console_account_and_tenant, +) + + +def test_file_upload_config_returns_console_limits( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise the authenticated upload-config route and response contract.""" + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + headers = authenticate_console_client(test_client_with_containers, account) + + response = test_client_with_containers.get( + "/console/api/files/upload", + headers=headers, + ) + + assert response.status_code == 200 + assert response.json == { + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT, + "file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT, + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + "image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT, + "single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT, + "attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT, + } + + +def test_file_upload_persists_file_for_authenticated_current_user( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise real upload behavior plus current-user and tenant propagation.""" + account, tenant = create_console_account_and_tenant(db_session_with_containers) + account_id = account.id + tenant_id = tenant.id + headers = authenticate_console_client(test_client_with_containers, account) + content = b"hello from console integration" + + response = test_client_with_containers.post( + "/console/api/files/upload", + headers=headers, + data={"file": (BytesIO(content), "tenant-owned.txt")}, + content_type="multipart/form-data", + ) + + assert response.status_code == 201 + assert response.json is not None + assert response.json["name"] == "tenant-owned.txt" + assert response.json["size"] == len(content) + assert response.json["extension"] == "txt" + assert response.json["mime_type"] == "text/plain" + assert response.json["created_by"] == account_id + + upload_file = db_session_with_containers.scalar( + select(UploadFile).where(UploadFile.id == response.json["id"]).limit(1) + ) + assert upload_file is not None + assert upload_file.tenant_id == tenant_id + assert upload_file.created_by == account_id + assert upload_file.name == "tenant-owned.txt" + assert upload_file.size == len(content) + assert f"/{tenant_id}/" in upload_file.key + + +def test_file_upload_rejects_missing_file_after_authentication( + db_session_with_containers: Session, + test_client_with_containers: FlaskClient, +) -> None: + """Exercise the route's validation path with a real authenticated account.""" + account, _tenant = create_console_account_and_tenant(db_session_with_containers) + headers = authenticate_console_client(test_client_with_containers, account) + + response = test_client_with_containers.post( + "/console/api/files/upload", + headers=headers, + data={}, + content_type="multipart/form-data", + ) + + assert response.status_code == 400 + assert response.json is not None + assert response.json["code"] == "no_file_uploaded" diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py index 3ed65b1ffb..3e76e6c21a 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_external.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -70,13 +70,14 @@ class TestExternalApiTemplateListApi: ExternalDatasetService, "get_external_knowledge_apis", return_value=([api_item], 1), - ), + ) as get_external_knowledge_apis, ): - resp, status = method(api, "id") + resp, status = method(api, "tenant-1") assert status == 200 assert resp["total"] == 1 assert resp["data"][0]["id"] == "1" + get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) def test_post_forbidden(self, app: Flask, current_user): current_user.is_dataset_editor = False @@ -321,13 +322,14 @@ class TestExternalApiTemplateListApiAdvanced: patch( "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis", return_value=(templates, 25), - ), + ) as get_external_knowledge_apis, ): - resp, status = method(api, "id") + resp, status = method(api, "tenant-1") assert status == 200 assert resp["total"] == 25 assert len(resp["data"]) == 3 + get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None) class TestExternalDatasetCreateApiAdvanced: diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py index f4916f013c..32b39de515 100644 --- a/api/tests/unit_tests/controllers/console/tag/test_tags.py +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -13,6 +13,8 @@ from controllers.console.tag.tags import ( TagListApi, TagUpdateDeleteApi, ) +from models import Account +from models.account import AccountStatus, TenantAccountRole from models.enums import TagType from services.tag_service import UpdateTagPayload @@ -35,20 +37,26 @@ def app(): @pytest.fixture def admin_user(): - return MagicMock( - id="user-1", - has_edit_permission=True, - is_dataset_editor=True, + account = Account( + name="Admin User", + email="admin@example.com", + status=AccountStatus.ACTIVE, ) + account.id = "user-1" + account.role = TenantAccountRole.OWNER + return account @pytest.fixture def readonly_user(): - return MagicMock( - id="user-2", - has_edit_permission=False, - is_dataset_editor=False, + account = Account( + name="Readonly User", + email="readonly@example.com", + status=AccountStatus.ACTIVE, ) + account.id = "user-2" + account.role = TenantAccountRole.NORMAL + return account @pytest.fixture @@ -80,10 +88,6 @@ class TestTagListApi: with app.test_request_context("/?type=knowledge"): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(MagicMock(), "tenant-1"), - ), patch( "controllers.console.tag.tags.TagService.get_tags", return_value=[ @@ -96,7 +100,7 @@ class TestTagListApi: ], ), ): - result, status = method(api) + result, status = method(api, "tenant-1") assert status == 200 assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}] @@ -109,17 +113,13 @@ class TestTagListApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch( "controllers.console.tag.tags.TagService.save_tags", return_value=tag, ), ): - result, status = method(api) + result, status = method(api, admin_user) assert status == 200 assert result["name"] == "test-tag" @@ -133,14 +133,10 @@ class TestTagListApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch(payload), ): with pytest.raises(Forbidden): - method(api) + method(api, readonly_user) class TestTagUpdateDeleteApi: @@ -152,10 +148,6 @@ class TestTagUpdateDeleteApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch( "controllers.console.tag.tags.TagService.update_tags", @@ -166,7 +158,7 @@ class TestTagUpdateDeleteApi: return_value=3, ), ): - result, status = method(api, "tag-1") + result, status = method(api, admin_user, "tag-1") assert status == 200 update_payload, tag_id = update_tags_mock.call_args.args @@ -182,14 +174,10 @@ class TestTagUpdateDeleteApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch(payload), ): with pytest.raises(Forbidden): - method(api, "tag-1") + method(api, readonly_user, "tag-1") def test_delete_success(self, app: Flask, admin_user): api = TagUpdateDeleteApi() @@ -197,10 +185,6 @@ class TestTagUpdateDeleteApi: with ( app.test_request_context("/"), - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, "tenant-1"), - ), patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock, ): result, status = method(api, "tag-1") @@ -222,14 +206,10 @@ class TestTagBindingCollectionApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, ): - result, status = method(api) + result, status = method(api, admin_user) save_mock.assert_called_once() assert status == 200 @@ -241,14 +221,10 @@ class TestTagBindingCollectionApi: with app.test_request_context("/", json={}): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch({}), ): with pytest.raises(Forbidden): - method(api) + method(api, readonly_user) class TestTagBindingRemoveApi: @@ -264,14 +240,10 @@ class TestTagBindingRemoveApi: with app.test_request_context("/", json=payload): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(admin_user, None), - ), payload_patch(payload), patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, ): - result, status = method(api) + result, status = method(api, admin_user) delete_mock.assert_called_once() delete_payload = delete_mock.call_args.args[0] @@ -285,14 +257,10 @@ class TestTagBindingRemoveApi: with app.test_request_context("/", json={}): with ( - patch( - "controllers.console.tag.tags.current_account_with_tenant", - return_value=(readonly_user, None), - ), payload_patch({}), ): with pytest.raises(Forbidden): - method(api) + method(api, readonly_user) class TestTagResponseModel: diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py index d566486664..f6ef1cb824 100644 --- a/api/tests/unit_tests/controllers/console/test_files.py +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -19,6 +19,8 @@ from controllers.console.files import ( FilePreviewApi, FileSupportTypeApi, ) +from models import Account +from models.account import AccountStatus, TenantAccountRole def unwrap(func): @@ -53,14 +55,15 @@ def mock_decorators(): @pytest.fixture def mock_current_user(): - user = MagicMock() - user.is_dataset_editor = True + user = Account(name="Test User", email="user-1@example.com", status=AccountStatus.ACTIVE) + user.id = "user-1" + user.role = TenantAccountRole.OWNER return user @pytest.fixture -def mock_current_tenant_id(): - return "tenant-123" +def mock_account_context(mock_current_user): + return mock_current_user @pytest.fixture @@ -91,15 +94,15 @@ class TestFileApiGet: class TestFileApiPost: - def test_no_file_uploaded(self, app: Flask, mock_current_user): + def test_no_file_uploaded(self, app: Flask, mock_account_context): api = FileApi() post_method = unwrap(api.post) with app.test_request_context(method="POST", data={}): with pytest.raises(NoFileUploadedError): - post_method(api, mock_current_user) + post_method(api, mock_account_context) - def test_too_many_files(self, app: Flask, mock_current_user): + def test_too_many_files(self, app: Flask, mock_account_context): api = FileApi() post_method = unwrap(api.post) @@ -114,9 +117,9 @@ class TestFileApiPost: mock_request.form.get.return_value = None with pytest.raises(TooManyFilesError): - post_method(api, mock_current_user) + post_method(api, mock_account_context) - def test_filename_missing(self, app: Flask, mock_current_user): + def test_filename_missing(self, app: Flask, mock_account_context): api = FileApi() post_method = unwrap(api.post) @@ -126,10 +129,10 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(FilenameNotExistsError): - post_method(api, mock_current_user) + post_method(api, mock_account_context) def test_dataset_upload_without_permission(self, app: Flask, mock_current_user): - mock_current_user.is_dataset_editor = False + mock_current_user.role = TenantAccountRole.NORMAL api = FileApi() post_method = unwrap(api.post) @@ -143,7 +146,7 @@ class TestFileApiPost: with pytest.raises(Forbidden): post_method(api, mock_current_user) - def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service): + def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -171,13 +174,13 @@ class TestFileApiPost: } with app.test_request_context(method="POST", data=data): - response, status = post_method(api, mock_current_user) + response, status = post_method(api, mock_account_context) assert status == 201 assert response["id"] == "file-id-123" assert response["name"] == "test.txt" - def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service): + def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service): """Test that invalid source parameter gets normalized to None""" api = FileApi() post_method = unwrap(api.post) @@ -208,7 +211,7 @@ class TestFileApiPost: } with app.test_request_context(method="POST", data=data): - response, status = post_method(api, mock_current_user) + response, status = post_method(api, mock_account_context) assert status == 201 assert response["id"] == "file-id-456" @@ -217,7 +220,7 @@ class TestFileApiPost: call_kwargs = mock_file_service.upload_file.call_args[1] assert call_kwargs["source"] is None - def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service): + def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -232,9 +235,9 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(FileTooLargeError): - post_method(api, mock_current_user) + post_method(api, mock_account_context) - def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service): + def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -249,9 +252,9 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(UnsupportedFileTypeError): - post_method(api, mock_current_user) + post_method(api, mock_account_context) - def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service): + def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service): api = FileApi() post_method = unwrap(api.post) @@ -266,17 +269,17 @@ class TestFileApiPost: with app.test_request_context(method="POST", data=data): with pytest.raises(BlockedFileExtensionError): - post_method(api, mock_current_user) + post_method(api, mock_account_context) class TestFilePreviewApi: - def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service): + def test_get_preview(self, app: Flask, mock_account_context, mock_file_service): api = FilePreviewApi() get_method = unwrap(api.get) mock_file_service.get_file_preview.return_value = "preview text" with app.test_request_context(): - result = get_method(api, mock_current_tenant_id, "1234") + result = get_method(api, "tenant-123", "1234") assert result == {"content": "preview text"} diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index ae620b1e52..7c7abdcf2d 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -10,6 +10,8 @@ import pytest from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError from controllers.console import remote_files as remote_files_module +from models import Account +from models.account import AccountStatus, TenantAccountRole from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError @@ -20,6 +22,17 @@ def _unwrap(func): return func +def _make_account(account_id: str = "u1") -> Account: + account = Account( + name="Test User", + email=f"{account_id}@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = account_id + account.role = TenantAccountRole.OWNER + return account + + class _FakeResponse: def __init__( self, @@ -48,7 +61,6 @@ def _mock_upload_dependencies( *, file_size_within_limit: bool = True, ): - current_user = SimpleNamespace(id="u1") file_info = SimpleNamespace( filename="report.txt", extension=".txt", @@ -64,6 +76,7 @@ def _mock_upload_dependencies( file_service_cls = MagicMock() file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit monkeypatch.setattr(remote_files_module, "FileService", file_service_cls) + current_user = _make_account() monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr( remote_files_module.file_helpers, @@ -226,7 +239,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"): - handler(api, SimpleNamespace(id="u1")) + handler(api, _make_account()) def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: @@ -243,7 +256,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte with app.test_request_context(method="POST", json={"url": url}): with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"): - handler(api, SimpleNamespace(id="u1")) + handler(api, _make_account()) def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index fe033144d6..6ddb1748d6 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_login import LoginManager, UserMixin +from werkzeug.exceptions import HTTPException from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout from controllers.console.workspace.error import AccountNotInitializedError @@ -17,8 +18,11 @@ from controllers.console.wraps import ( only_edition_enterprise, only_edition_self_hosted, setup_required, + with_current_tenant_id, + with_current_user, ) -from models.account import AccountStatus +from models import Account +from models.account import AccountStatus, TenantAccountRole from services.feature_service import LicenseStatus @@ -33,6 +37,17 @@ class MockUser(UserMixin): return self.id +def make_account(account_id: str = "account-1") -> Account: + account = Account( + name="Test Account", + email=f"{account_id}@example.com", + status=AccountStatus.ACTIVE, + ) + account.id = account_id + account.role = TenantAccountRole.OWNER + return account + + def create_app_with_login(): """Create a Flask app with LoginManager configured.""" app = Flask(__name__) @@ -84,6 +99,42 @@ class TestAccountInitialization: protected_view() +class TestCurrentContextInjection: + """Test request context injection decorators.""" + + def test_with_current_tenant_id_injects_tenant_id(self): + class Handler: + @with_current_tenant_id + def get(self, current_tenant_id: str): + return current_tenant_id + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-123")): + assert Handler().get() == "tenant-123" + + def test_with_current_user_injects_account(self): + current_user = make_account() + + class Handler: + @with_current_user + def get(self, injected_user): + return injected_user + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): + assert Handler().get() is current_user + + def test_stacked_current_context_injectors_preserve_argument_order(self): + current_user = make_account() + + class Handler: + @with_current_user + @with_current_tenant_id + def get(self, current_tenant_id: str, injected_user): + return current_tenant_id, injected_user + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")): + assert Handler().get() == ("tenant-123", current_user) + + class TestEditionChecks: """Test edition-specific decorators""" @@ -114,7 +165,7 @@ class TestEditionChecks: # Act & Assert with app.test_request_context(): with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: cloud_view() assert exc_info.value.code == 404 @@ -177,7 +228,7 @@ class TestBillingEnabled: with app.test_request_context(): with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False): with patch("controllers.console.wraps.FeatureService.get_features") as get_features: - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: billing_view() assert exc_info.value.code == 403 @@ -230,7 +281,7 @@ class TestBillingResourceLimits: return_value=(MockUser("test_user"), "tenant123"), ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: add_member() assert exc_info.value.code == 403 assert "members has reached the limit" in str(exc_info.value.description) @@ -255,7 +306,7 @@ class TestBillingResourceLimits: return_value=(MockUser("test_user"), "tenant123"), ): with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: upload_document() assert exc_info.value.code == 403 @@ -329,7 +380,7 @@ class TestRateLimiting: with patch( "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit ): - with pytest.raises(Exception) as exc_info: + with pytest.raises(HTTPException) as exc_info: knowledge_request() # Verify error