refactor(api): migrate console tags to tenant/user via DI and improve tests (#36658)

Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
chariri 2026-05-26 17:20:10 +09:00 committed by GitHub
parent 533929d314
commit 59e99ee1ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 441 additions and 114 deletions

View File

@ -28,7 +28,7 @@ from controllers.console.wraps import (
from extensions.ext_database import db
from fields.file_fields import FileResponse, UploadConfig
from libs.login import login_required
from models.account import Account
from models import Account
from services.file_service import FileService
from . import console_ns

View File

@ -18,7 +18,7 @@ from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from graphon.file import helpers as file_helpers
from libs.login import login_required
from models.account import Account
from models import Account
from services.file_service import FileService

View File

@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_tenant_id,
with_current_user,
)
from fields.base import ResponseModel
from libs.login import current_account_with_tenant, login_required
from libs.login import login_required
from models import Account
from models.enums import TagType
from services.tag_service import (
SaveTagPayload,
@ -92,8 +99,8 @@ class TagListApi(Resource):
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
def get(self):
_, current_tenant_id = current_account_with_tenant()
@with_current_tenant_id
def get(self, current_tenant_id: str):
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
@ -109,9 +116,9 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, or editor
@with_current_user
def post(self, current_user: Account):
# Allow users with edit permission, or dataset editors (including dataset operators).
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
@ -132,8 +139,8 @@ class TagUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, tag_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def patch(self, current_user: Account, tag_id: UUID):
tag_id_str = str(tag_id)
# The role of the current user in the ta table must be admin, owner, or editor
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
@ -163,20 +170,19 @@ class TagUpdateDeleteApi(Resource):
return "", 204
def _require_tag_binding_edit_permission() -> None:
def _require_tag_binding_edit_permission(current_user: Account) -> None:
"""
Ensure the current account can edit tag bindings.
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
"""
current_user, _ = current_account_with_tenant()
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
raise Forbidden()
def _create_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission(current_user)
payload = TagBindingPayload.model_validate(console_ns.payload or {})
TagService.save_tag_binding(
@ -189,8 +195,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
return {"result": "success"}, 200
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission()
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
_require_tag_binding_edit_permission(current_user)
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
TagService.delete_tag_binding(
@ -213,8 +219,9 @@ class TagBindingCollectionApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
return _create_tag_bindings()
@with_current_user
def post(self, current_user: Account):
return _create_tag_bindings(current_user)
@console_ns.route("/tag-bindings/remove")
@ -228,5 +235,6 @@ class TagBindingRemoveApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
return _remove_tag_bindings()
@with_current_user
def post(self, current_user: Account):
return _remove_tag_bindings(current_user)

View File

@ -0,0 +1,116 @@
"""Integration tests for console external knowledge API endpoints."""
from __future__ import annotations
import json
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from models.dataset import ExternalKnowledgeApis
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def _create_external_api(
db_session: Session,
*,
tenant_id: str,
account_id: str,
name: str,
) -> ExternalKnowledgeApis:
external_api = ExternalKnowledgeApis(
tenant_id=tenant_id,
created_by=account_id,
updated_by=account_id,
name=name,
description=f"{name} description",
settings=json.dumps(
{
"endpoint": "https://example.com",
"api_key": "test-api-key",
}
),
)
db_session.add(external_api)
db_session.commit()
return external_api
def test_external_api_template_list_filters_paginates_and_scopes_to_authenticated_tenant(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
"""Exercise the real list route, including query parsing, DB lookup, and tenant isolation."""
account, tenant = create_console_account_and_tenant(db_session_with_containers)
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
account_id = account.id
tenant_id = tenant.id
foreign_account_id = foreign_account.id
foreign_tenant_id = foreign_tenant.id
headers = authenticate_console_client(test_client_with_containers, account)
_create_external_api(
db_session_with_containers,
tenant_id=tenant_id,
account_id=account_id,
name="Alpha Primary",
)
_create_external_api(
db_session_with_containers,
tenant_id=tenant_id,
account_id=account_id,
name="Alpha Secondary",
)
_create_external_api(
db_session_with_containers,
tenant_id=tenant_id,
account_id=account_id,
name="Beta Unmatched",
)
_create_external_api(
db_session_with_containers,
tenant_id=foreign_tenant_id,
account_id=foreign_account_id,
name="Alpha Foreign",
)
response = test_client_with_containers.get(
"/console/api/datasets/external-knowledge-api?page=1&limit=1&keyword=Alpha",
headers=headers,
)
assert response.status_code == 200
assert response.json is not None
assert response.json["page"] == 1
assert response.json["limit"] == 1
assert response.json["total"] == 2
assert response.json["has_more"] is True
assert len(response.json["data"]) == 1
first_page_item = response.json["data"][0]
assert first_page_item["tenant_id"] == tenant_id
assert first_page_item["name"] in {"Alpha Primary", "Alpha Secondary"}
assert first_page_item["settings"] == {
"endpoint": "https://example.com",
"api_key": "test-api-key",
}
assert first_page_item["dataset_bindings"] == []
second_response = test_client_with_containers.get(
"/console/api/datasets/external-knowledge-api?page=2&limit=1&keyword=Alpha",
headers=headers,
)
assert second_response.status_code == 200
assert second_response.json is not None
assert second_response.json["page"] == 2
assert second_response.json["limit"] == 1
assert second_response.json["total"] == 2
assert len(second_response.json["data"]) == 1
second_page_item = second_response.json["data"][0]
assert second_page_item["name"] in {"Alpha Primary", "Alpha Secondary"}
assert second_response.json["data"][0]["tenant_id"] == tenant_id

View File

@ -0,0 +1,65 @@
"""Integration tests for console feature endpoints."""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import patch
from flask.testing import FlaskClient
from sqlalchemy.orm import Session
from services.feature_service import FeatureModel, FeatureService, LimitationModel
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def test_feature_list_returns_current_tenant_configuration_without_vector_space(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
"""Exercise auth, tenant injection, and the feature response shaping contract."""
account, tenant = create_console_account_and_tenant(db_session_with_containers)
tenant_id = tenant.id
headers = authenticate_console_client(test_client_with_containers, account)
feature_model = FeatureModel(
members=LimitationModel(size=1, limit=2),
apps=LimitationModel(size=3, limit=4),
vector_space=LimitationModel(size=5, limit=6),
)
with patch.object(FeatureService, "get_features", return_value=feature_model) as get_features:
response = test_client_with_containers.get(
"/console/api/features",
headers=headers,
)
assert response.status_code == 200
assert response.json is not None
assert response.json["members"] == {"size": 1, "limit": 2}
assert response.json["apps"] == {"size": 3, "limit": 4}
assert "vector_space" not in response.json
get_features.assert_called_once_with(tenant_id, exclude_vector_space=True)
def test_feature_vector_space_returns_current_tenant_usage(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
"""Exercise tenant injection and vector-space response serialization through the registered route."""
account, tenant = create_console_account_and_tenant(db_session_with_containers)
tenant_id = tenant.id
headers = authenticate_console_client(test_client_with_containers, account)
vector_space = SimpleNamespace(model_dump=lambda: {"size": 0, "limit": 100})
with patch.object(FeatureService, "get_vector_space", return_value=vector_space) as get_vector_space:
response = test_client_with_containers.get(
"/console/api/features/vector-space",
headers=headers,
)
assert response.status_code == 200
assert response.json == {"size": 0, "limit": 100}
get_vector_space.assert_called_once_with(tenant_id)

View File

@ -0,0 +1,101 @@
"""Integration tests for console file endpoints."""
from __future__ import annotations
from io import BytesIO
from flask.testing import FlaskClient
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from models.model import UploadFile
from tests.test_containers_integration_tests.controllers.console.helpers import (
authenticate_console_client,
create_console_account_and_tenant,
)
def test_file_upload_config_returns_console_limits(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
"""Exercise the authenticated upload-config route and response contract."""
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
headers = authenticate_console_client(test_client_with_containers, account)
response = test_client_with_containers.get(
"/console/api/files/upload",
headers=headers,
)
assert response.status_code == 200
assert response.json == {
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
}
def test_file_upload_persists_file_for_authenticated_current_user(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
"""Exercise real upload behavior plus current-user and tenant propagation."""
account, tenant = create_console_account_and_tenant(db_session_with_containers)
account_id = account.id
tenant_id = tenant.id
headers = authenticate_console_client(test_client_with_containers, account)
content = b"hello from console integration"
response = test_client_with_containers.post(
"/console/api/files/upload",
headers=headers,
data={"file": (BytesIO(content), "tenant-owned.txt")},
content_type="multipart/form-data",
)
assert response.status_code == 201
assert response.json is not None
assert response.json["name"] == "tenant-owned.txt"
assert response.json["size"] == len(content)
assert response.json["extension"] == "txt"
assert response.json["mime_type"] == "text/plain"
assert response.json["created_by"] == account_id
upload_file = db_session_with_containers.scalar(
select(UploadFile).where(UploadFile.id == response.json["id"]).limit(1)
)
assert upload_file is not None
assert upload_file.tenant_id == tenant_id
assert upload_file.created_by == account_id
assert upload_file.name == "tenant-owned.txt"
assert upload_file.size == len(content)
assert f"/{tenant_id}/" in upload_file.key
def test_file_upload_rejects_missing_file_after_authentication(
db_session_with_containers: Session,
test_client_with_containers: FlaskClient,
) -> None:
"""Exercise the route's validation path with a real authenticated account."""
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
headers = authenticate_console_client(test_client_with_containers, account)
response = test_client_with_containers.post(
"/console/api/files/upload",
headers=headers,
data={},
content_type="multipart/form-data",
)
assert response.status_code == 400
assert response.json is not None
assert response.json["code"] == "no_file_uploaded"

View File

@ -70,13 +70,14 @@ class TestExternalApiTemplateListApi:
ExternalDatasetService,
"get_external_knowledge_apis",
return_value=([api_item], 1),
),
) as get_external_knowledge_apis,
):
resp, status = method(api, "id")
resp, status = method(api, "tenant-1")
assert status == 200
assert resp["total"] == 1
assert resp["data"][0]["id"] == "1"
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
def test_post_forbidden(self, app: Flask, current_user):
current_user.is_dataset_editor = False
@ -321,13 +322,14 @@ class TestExternalApiTemplateListApiAdvanced:
patch(
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
return_value=(templates, 25),
),
) as get_external_knowledge_apis,
):
resp, status = method(api, "id")
resp, status = method(api, "tenant-1")
assert status == 200
assert resp["total"] == 25
assert len(resp["data"]) == 3
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
class TestExternalDatasetCreateApiAdvanced:

View File

@ -13,6 +13,8 @@ from controllers.console.tag.tags import (
TagListApi,
TagUpdateDeleteApi,
)
from models import Account
from models.account import AccountStatus, TenantAccountRole
from models.enums import TagType
from services.tag_service import UpdateTagPayload
@ -35,20 +37,26 @@ def app():
@pytest.fixture
def admin_user():
return MagicMock(
id="user-1",
has_edit_permission=True,
is_dataset_editor=True,
account = Account(
name="Admin User",
email="admin@example.com",
status=AccountStatus.ACTIVE,
)
account.id = "user-1"
account.role = TenantAccountRole.OWNER
return account
@pytest.fixture
def readonly_user():
return MagicMock(
id="user-2",
has_edit_permission=False,
is_dataset_editor=False,
account = Account(
name="Readonly User",
email="readonly@example.com",
status=AccountStatus.ACTIVE,
)
account.id = "user-2"
account.role = TenantAccountRole.NORMAL
return account
@pytest.fixture
@ -80,10 +88,6 @@ class TestTagListApi:
with app.test_request_context("/?type=knowledge"):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(MagicMock(), "tenant-1"),
),
patch(
"controllers.console.tag.tags.TagService.get_tags",
return_value=[
@ -96,7 +100,7 @@ class TestTagListApi:
],
),
):
result, status = method(api)
result, status = method(api, "tenant-1")
assert status == 200
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
@ -109,17 +113,13 @@ class TestTagListApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.save_tags",
return_value=tag,
),
):
result, status = method(api)
result, status = method(api, admin_user)
assert status == 200
assert result["name"] == "test-tag"
@ -133,14 +133,10 @@ class TestTagListApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api)
method(api, readonly_user)
class TestTagUpdateDeleteApi:
@ -152,10 +148,6 @@ class TestTagUpdateDeleteApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch(
"controllers.console.tag.tags.TagService.update_tags",
@ -166,7 +158,7 @@ class TestTagUpdateDeleteApi:
return_value=3,
),
):
result, status = method(api, "tag-1")
result, status = method(api, admin_user, "tag-1")
assert status == 200
update_payload, tag_id = update_tags_mock.call_args.args
@ -182,14 +174,10 @@ class TestTagUpdateDeleteApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch(payload),
):
with pytest.raises(Forbidden):
method(api, "tag-1")
method(api, readonly_user, "tag-1")
def test_delete_success(self, app: Flask, admin_user):
api = TagUpdateDeleteApi()
@ -197,10 +185,6 @@ class TestTagUpdateDeleteApi:
with (
app.test_request_context("/"),
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, "tenant-1"),
),
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
):
result, status = method(api, "tag-1")
@ -222,14 +206,10 @@ class TestTagBindingCollectionApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
):
result, status = method(api)
result, status = method(api, admin_user)
save_mock.assert_called_once()
assert status == 200
@ -241,14 +221,10 @@ class TestTagBindingCollectionApi:
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)
method(api, readonly_user)
class TestTagBindingRemoveApi:
@ -264,14 +240,10 @@ class TestTagBindingRemoveApi:
with app.test_request_context("/", json=payload):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(admin_user, None),
),
payload_patch(payload),
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
):
result, status = method(api)
result, status = method(api, admin_user)
delete_mock.assert_called_once()
delete_payload = delete_mock.call_args.args[0]
@ -285,14 +257,10 @@ class TestTagBindingRemoveApi:
with app.test_request_context("/", json={}):
with (
patch(
"controllers.console.tag.tags.current_account_with_tenant",
return_value=(readonly_user, None),
),
payload_patch({}),
):
with pytest.raises(Forbidden):
method(api)
method(api, readonly_user)
class TestTagResponseModel:

View File

@ -19,6 +19,8 @@ from controllers.console.files import (
FilePreviewApi,
FileSupportTypeApi,
)
from models import Account
from models.account import AccountStatus, TenantAccountRole
def unwrap(func):
@ -53,14 +55,15 @@ def mock_decorators():
@pytest.fixture
def mock_current_user():
user = MagicMock()
user.is_dataset_editor = True
user = Account(name="Test User", email="user-1@example.com", status=AccountStatus.ACTIVE)
user.id = "user-1"
user.role = TenantAccountRole.OWNER
return user
@pytest.fixture
def mock_current_tenant_id():
return "tenant-123"
def mock_account_context(mock_current_user):
return mock_current_user
@pytest.fixture
@ -91,15 +94,15 @@ class TestFileApiGet:
class TestFileApiPost:
def test_no_file_uploaded(self, app: Flask, mock_current_user):
def test_no_file_uploaded(self, app: Flask, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
with app.test_request_context(method="POST", data={}):
with pytest.raises(NoFileUploadedError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_too_many_files(self, app: Flask, mock_current_user):
def test_too_many_files(self, app: Flask, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
@ -114,9 +117,9 @@ class TestFileApiPost:
mock_request.form.get.return_value = None
with pytest.raises(TooManyFilesError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_filename_missing(self, app: Flask, mock_current_user):
def test_filename_missing(self, app: Flask, mock_account_context):
api = FileApi()
post_method = unwrap(api.post)
@ -126,10 +129,10 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(FilenameNotExistsError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
mock_current_user.is_dataset_editor = False
mock_current_user.role = TenantAccountRole.NORMAL
api = FileApi()
post_method = unwrap(api.post)
@ -143,7 +146,7 @@ class TestFileApiPost:
with pytest.raises(Forbidden):
post_method(api, mock_current_user)
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -171,13 +174,13 @@ class TestFileApiPost:
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api, mock_current_user)
response, status = post_method(api, mock_account_context)
assert status == 201
assert response["id"] == "file-id-123"
assert response["name"] == "test.txt"
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
"""Test that invalid source parameter gets normalized to None"""
api = FileApi()
post_method = unwrap(api.post)
@ -208,7 +211,7 @@ class TestFileApiPost:
}
with app.test_request_context(method="POST", data=data):
response, status = post_method(api, mock_current_user)
response, status = post_method(api, mock_account_context)
assert status == 201
assert response["id"] == "file-id-456"
@ -217,7 +220,7 @@ class TestFileApiPost:
call_kwargs = mock_file_service.upload_file.call_args[1]
assert call_kwargs["source"] is None
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -232,9 +235,9 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(FileTooLargeError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -249,9 +252,9 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(UnsupportedFileTypeError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
api = FileApi()
post_method = unwrap(api.post)
@ -266,17 +269,17 @@ class TestFileApiPost:
with app.test_request_context(method="POST", data=data):
with pytest.raises(BlockedFileExtensionError):
post_method(api, mock_current_user)
post_method(api, mock_account_context)
class TestFilePreviewApi:
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
api = FilePreviewApi()
get_method = unwrap(api.get)
mock_file_service.get_file_preview.return_value = "preview text"
with app.test_request_context():
result = get_method(api, mock_current_tenant_id, "1234")
result = get_method(api, "tenant-123", "1234")
assert result == {"content": "preview text"}

View File

@ -10,6 +10,8 @@ import pytest
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
from controllers.console import remote_files as remote_files_module
from models import Account
from models.account import AccountStatus, TenantAccountRole
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
@ -20,6 +22,17 @@ def _unwrap(func):
return func
def _make_account(account_id: str = "u1") -> Account:
account = Account(
name="Test User",
email=f"{account_id}@example.com",
status=AccountStatus.ACTIVE,
)
account.id = account_id
account.role = TenantAccountRole.OWNER
return account
class _FakeResponse:
def __init__(
self,
@ -48,7 +61,6 @@ def _mock_upload_dependencies(
*,
file_size_within_limit: bool = True,
):
current_user = SimpleNamespace(id="u1")
file_info = SimpleNamespace(
filename="report.txt",
extension=".txt",
@ -64,6 +76,7 @@ def _mock_upload_dependencies(
file_service_cls = MagicMock()
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
current_user = _make_account()
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(
remote_files_module.file_helpers,
@ -226,7 +239,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
handler(api, SimpleNamespace(id="u1"))
handler(api, _make_account())
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
@ -243,7 +256,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
with app.test_request_context(method="POST", json={"url": url}):
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
handler(api, SimpleNamespace(id="u1"))
handler(api, _make_account())
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin
from werkzeug.exceptions import HTTPException
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
from controllers.console.workspace.error import AccountNotInitializedError
@ -17,8 +18,11 @@ from controllers.console.wraps import (
only_edition_enterprise,
only_edition_self_hosted,
setup_required,
with_current_tenant_id,
with_current_user,
)
from models.account import AccountStatus
from models import Account
from models.account import AccountStatus, TenantAccountRole
from services.feature_service import LicenseStatus
@ -33,6 +37,17 @@ class MockUser(UserMixin):
return self.id
def make_account(account_id: str = "account-1") -> Account:
account = Account(
name="Test Account",
email=f"{account_id}@example.com",
status=AccountStatus.ACTIVE,
)
account.id = account_id
account.role = TenantAccountRole.OWNER
return account
def create_app_with_login():
"""Create a Flask app with LoginManager configured."""
app = Flask(__name__)
@ -84,6 +99,42 @@ class TestAccountInitialization:
protected_view()
class TestCurrentContextInjection:
"""Test request context injection decorators."""
def test_with_current_tenant_id_injects_tenant_id(self):
class Handler:
@with_current_tenant_id
def get(self, current_tenant_id: str):
return current_tenant_id
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-123")):
assert Handler().get() == "tenant-123"
def test_with_current_user_injects_account(self):
current_user = make_account()
class Handler:
@with_current_user
def get(self, injected_user):
return injected_user
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
assert Handler().get() is current_user
def test_stacked_current_context_injectors_preserve_argument_order(self):
current_user = make_account()
class Handler:
@with_current_user
@with_current_tenant_id
def get(self, current_tenant_id: str, injected_user):
return current_tenant_id, injected_user
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
assert Handler().get() == ("tenant-123", current_user)
class TestEditionChecks:
"""Test edition-specific decorators"""
@ -114,7 +165,7 @@ class TestEditionChecks:
# Act & Assert
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
cloud_view()
assert exc_info.value.code == 404
@ -177,7 +228,7 @@ class TestBillingEnabled:
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False):
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
billing_view()
assert exc_info.value.code == 403
@ -230,7 +281,7 @@ class TestBillingResourceLimits:
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
add_member()
assert exc_info.value.code == 403
assert "members has reached the limit" in str(exc_info.value.description)
@ -255,7 +306,7 @@ class TestBillingResourceLimits:
return_value=(MockUser("test_user"), "tenant123"),
):
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
upload_document()
assert exc_info.value.code == 403
@ -329,7 +380,7 @@ class TestRateLimiting:
with patch(
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
):
with pytest.raises(Exception) as exc_info:
with pytest.raises(HTTPException) as exc_info:
knowledge_request()
# Verify error