mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
refactor(api): migrate console tags to tenant/user via DI and improve tests (#36658)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
533929d314
commit
59e99ee1ae
@ -28,7 +28,7 @@ from controllers.console.wraps import (
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileResponse, UploadConfig
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
from . import console_ns
|
||||
|
||||
@ -18,7 +18,7 @@ from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from graphon.file import helpers as file_helpers
|
||||
from libs.login import login_required
|
||||
from models.account import Account
|
||||
from models import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
|
||||
@ -9,9 +9,16 @@ from werkzeug.exceptions import Forbidden
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.enums import TagType
|
||||
from services.tag_service import (
|
||||
SaveTagPayload,
|
||||
@ -92,8 +99,8 @@ class TagListApi(Resource):
|
||||
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
|
||||
)
|
||||
@console_ns.doc(responses={200: ("Success", [console_ns.models[TagResponse.__name__]])})
|
||||
def get(self):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
raw_args = request.args.to_dict()
|
||||
param = TagListQueryParam.model_validate(raw_args)
|
||||
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
|
||||
@ -109,9 +116,9 @@ class TagListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
# Allow users with edit permission, or dataset editors (including dataset operators).
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
@ -132,8 +139,8 @@ class TagUpdateDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def patch(self, tag_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def patch(self, current_user: Account, tag_id: UUID):
|
||||
tag_id_str = str(tag_id)
|
||||
# The role of the current user in the ta table must be admin, owner, or editor
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
@ -163,20 +170,19 @@ class TagUpdateDeleteApi(Resource):
|
||||
return "", 204
|
||||
|
||||
|
||||
def _require_tag_binding_edit_permission() -> None:
|
||||
def _require_tag_binding_edit_permission(current_user: Account) -> None:
|
||||
"""
|
||||
Ensure the current account can edit tag bindings.
|
||||
|
||||
Tag binding operations are allowed for users who can edit resources (app/dataset) within the current tenant.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
||||
|
||||
def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
def _create_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
|
||||
payload = TagBindingPayload.model_validate(console_ns.payload or {})
|
||||
TagService.save_tag_binding(
|
||||
@ -189,8 +195,8 @@ def _create_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
return {"result": "success"}, 200
|
||||
|
||||
|
||||
def _remove_tag_bindings() -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission()
|
||||
def _remove_tag_bindings(current_user: Account) -> tuple[dict[str, str], int]:
|
||||
_require_tag_binding_edit_permission(current_user)
|
||||
|
||||
payload = TagBindingRemovePayload.model_validate(console_ns.payload or {})
|
||||
TagService.delete_tag_binding(
|
||||
@ -213,8 +219,9 @@ class TagBindingCollectionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _create_tag_bindings()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _create_tag_bindings(current_user)
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/remove")
|
||||
@ -228,5 +235,6 @@ class TagBindingRemoveApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
return _remove_tag_bindings()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account):
|
||||
return _remove_tag_bindings(current_user)
|
||||
|
||||
@ -0,0 +1,116 @@
|
||||
"""Integration tests for console external knowledge API endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.dataset import ExternalKnowledgeApis
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def _create_external_api(
|
||||
db_session: Session,
|
||||
*,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
name: str,
|
||||
) -> ExternalKnowledgeApis:
|
||||
external_api = ExternalKnowledgeApis(
|
||||
tenant_id=tenant_id,
|
||||
created_by=account_id,
|
||||
updated_by=account_id,
|
||||
name=name,
|
||||
description=f"{name} description",
|
||||
settings=json.dumps(
|
||||
{
|
||||
"endpoint": "https://example.com",
|
||||
"api_key": "test-api-key",
|
||||
}
|
||||
),
|
||||
)
|
||||
db_session.add(external_api)
|
||||
db_session.commit()
|
||||
return external_api
|
||||
|
||||
|
||||
def test_external_api_template_list_filters_paginates_and_scopes_to_authenticated_tenant(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
"""Exercise the real list route, including query parsing, DB lookup, and tenant isolation."""
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
foreign_account, foreign_tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
account_id = account.id
|
||||
tenant_id = tenant.id
|
||||
foreign_account_id = foreign_account.id
|
||||
foreign_tenant_id = foreign_tenant.id
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
|
||||
_create_external_api(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
name="Alpha Primary",
|
||||
)
|
||||
_create_external_api(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
name="Alpha Secondary",
|
||||
)
|
||||
_create_external_api(
|
||||
db_session_with_containers,
|
||||
tenant_id=tenant_id,
|
||||
account_id=account_id,
|
||||
name="Beta Unmatched",
|
||||
)
|
||||
_create_external_api(
|
||||
db_session_with_containers,
|
||||
tenant_id=foreign_tenant_id,
|
||||
account_id=foreign_account_id,
|
||||
name="Alpha Foreign",
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/datasets/external-knowledge-api?page=1&limit=1&keyword=Alpha",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json is not None
|
||||
assert response.json["page"] == 1
|
||||
assert response.json["limit"] == 1
|
||||
assert response.json["total"] == 2
|
||||
assert response.json["has_more"] is True
|
||||
assert len(response.json["data"]) == 1
|
||||
|
||||
first_page_item = response.json["data"][0]
|
||||
assert first_page_item["tenant_id"] == tenant_id
|
||||
assert first_page_item["name"] in {"Alpha Primary", "Alpha Secondary"}
|
||||
assert first_page_item["settings"] == {
|
||||
"endpoint": "https://example.com",
|
||||
"api_key": "test-api-key",
|
||||
}
|
||||
assert first_page_item["dataset_bindings"] == []
|
||||
|
||||
second_response = test_client_with_containers.get(
|
||||
"/console/api/datasets/external-knowledge-api?page=2&limit=1&keyword=Alpha",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert second_response.status_code == 200
|
||||
assert second_response.json is not None
|
||||
assert second_response.json["page"] == 2
|
||||
assert second_response.json["limit"] == 1
|
||||
assert second_response.json["total"] == 2
|
||||
assert len(second_response.json["data"]) == 1
|
||||
|
||||
second_page_item = second_response.json["data"][0]
|
||||
assert second_page_item["name"] in {"Alpha Primary", "Alpha Secondary"}
|
||||
assert second_response.json["data"][0]["tenant_id"] == tenant_id
|
||||
@ -0,0 +1,65 @@
|
||||
"""Integration tests for console feature endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.feature_service import FeatureModel, FeatureService, LimitationModel
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_feature_list_returns_current_tenant_configuration_without_vector_space(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
"""Exercise auth, tenant injection, and the feature response shaping contract."""
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
tenant_id = tenant.id
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
feature_model = FeatureModel(
|
||||
members=LimitationModel(size=1, limit=2),
|
||||
apps=LimitationModel(size=3, limit=4),
|
||||
vector_space=LimitationModel(size=5, limit=6),
|
||||
)
|
||||
|
||||
with patch.object(FeatureService, "get_features", return_value=feature_model) as get_features:
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/features",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json is not None
|
||||
assert response.json["members"] == {"size": 1, "limit": 2}
|
||||
assert response.json["apps"] == {"size": 3, "limit": 4}
|
||||
assert "vector_space" not in response.json
|
||||
get_features.assert_called_once_with(tenant_id, exclude_vector_space=True)
|
||||
|
||||
|
||||
def test_feature_vector_space_returns_current_tenant_usage(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
"""Exercise tenant injection and vector-space response serialization through the registered route."""
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
tenant_id = tenant.id
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
|
||||
vector_space = SimpleNamespace(model_dump=lambda: {"size": 0, "limit": 100})
|
||||
|
||||
with patch.object(FeatureService, "get_vector_space", return_value=vector_space) as get_vector_space:
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/features/vector-space",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json == {"size": 0, "limit": 100}
|
||||
get_vector_space.assert_called_once_with(tenant_id)
|
||||
@ -0,0 +1,101 @@
|
||||
"""Integration tests for console file endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from models.model import UploadFile
|
||||
from tests.test_containers_integration_tests.controllers.console.helpers import (
|
||||
authenticate_console_client,
|
||||
create_console_account_and_tenant,
|
||||
)
|
||||
|
||||
|
||||
def test_file_upload_config_returns_console_limits(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
"""Exercise the authenticated upload-config route and response contract."""
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
"/console/api/files/upload",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json == {
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"batch_count_limit": dify_config.UPLOAD_FILE_BATCH_LIMIT,
|
||||
"file_upload_limit": dify_config.BATCH_UPLOAD_LIMIT,
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
"image_file_batch_limit": dify_config.IMAGE_FILE_BATCH_LIMIT,
|
||||
"single_chunk_attachment_limit": dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT,
|
||||
"attachment_image_file_size_limit": dify_config.ATTACHMENT_IMAGE_FILE_SIZE_LIMIT,
|
||||
}
|
||||
|
||||
|
||||
def test_file_upload_persists_file_for_authenticated_current_user(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
"""Exercise real upload behavior plus current-user and tenant propagation."""
|
||||
account, tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
account_id = account.id
|
||||
tenant_id = tenant.id
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
content = b"hello from console integration"
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/files/upload",
|
||||
headers=headers,
|
||||
data={"file": (BytesIO(content), "tenant-owned.txt")},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
assert response.json is not None
|
||||
assert response.json["name"] == "tenant-owned.txt"
|
||||
assert response.json["size"] == len(content)
|
||||
assert response.json["extension"] == "txt"
|
||||
assert response.json["mime_type"] == "text/plain"
|
||||
assert response.json["created_by"] == account_id
|
||||
|
||||
upload_file = db_session_with_containers.scalar(
|
||||
select(UploadFile).where(UploadFile.id == response.json["id"]).limit(1)
|
||||
)
|
||||
assert upload_file is not None
|
||||
assert upload_file.tenant_id == tenant_id
|
||||
assert upload_file.created_by == account_id
|
||||
assert upload_file.name == "tenant-owned.txt"
|
||||
assert upload_file.size == len(content)
|
||||
assert f"/{tenant_id}/" in upload_file.key
|
||||
|
||||
|
||||
def test_file_upload_rejects_missing_file_after_authentication(
|
||||
db_session_with_containers: Session,
|
||||
test_client_with_containers: FlaskClient,
|
||||
) -> None:
|
||||
"""Exercise the route's validation path with a real authenticated account."""
|
||||
account, _tenant = create_console_account_and_tenant(db_session_with_containers)
|
||||
headers = authenticate_console_client(test_client_with_containers, account)
|
||||
|
||||
response = test_client_with_containers.post(
|
||||
"/console/api/files/upload",
|
||||
headers=headers,
|
||||
data={},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json is not None
|
||||
assert response.json["code"] == "no_file_uploaded"
|
||||
@ -70,13 +70,14 @@ class TestExternalApiTemplateListApi:
|
||||
ExternalDatasetService,
|
||||
"get_external_knowledge_apis",
|
||||
return_value=([api_item], 1),
|
||||
),
|
||||
) as get_external_knowledge_apis,
|
||||
):
|
||||
resp, status = method(api, "id")
|
||||
resp, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 1
|
||||
assert resp["data"][0]["id"] == "1"
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
|
||||
def test_post_forbidden(self, app: Flask, current_user):
|
||||
current_user.is_dataset_editor = False
|
||||
@ -321,13 +322,14 @@ class TestExternalApiTemplateListApiAdvanced:
|
||||
patch(
|
||||
"controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis",
|
||||
return_value=(templates, 25),
|
||||
),
|
||||
) as get_external_knowledge_apis,
|
||||
):
|
||||
resp, status = method(api, "id")
|
||||
resp, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp["total"] == 25
|
||||
assert len(resp["data"]) == 3
|
||||
get_external_knowledge_apis.assert_called_once_with(1, 20, "tenant-1", None)
|
||||
|
||||
|
||||
class TestExternalDatasetCreateApiAdvanced:
|
||||
|
||||
@ -13,6 +13,8 @@ from controllers.console.tag.tags import (
|
||||
TagListApi,
|
||||
TagUpdateDeleteApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from models.enums import TagType
|
||||
from services.tag_service import UpdateTagPayload
|
||||
|
||||
@ -35,20 +37,26 @@ def app():
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user():
|
||||
return MagicMock(
|
||||
id="user-1",
|
||||
has_edit_permission=True,
|
||||
is_dataset_editor=True,
|
||||
account = Account(
|
||||
name="Admin User",
|
||||
email="admin@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "user-1"
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def readonly_user():
|
||||
return MagicMock(
|
||||
id="user-2",
|
||||
has_edit_permission=False,
|
||||
is_dataset_editor=False,
|
||||
account = Account(
|
||||
name="Readonly User",
|
||||
email="readonly@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = "user-2"
|
||||
account.role = TenantAccountRole.NORMAL
|
||||
return account
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -80,10 +88,6 @@ class TestTagListApi:
|
||||
|
||||
with app.test_request_context("/?type=knowledge"):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(MagicMock(), "tenant-1"),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.get_tags",
|
||||
return_value=[
|
||||
@ -96,7 +100,7 @@ class TestTagListApi:
|
||||
],
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, "tenant-1")
|
||||
|
||||
assert status == 200
|
||||
assert result == [{"id": "1", "name": "tag", "type": "knowledge", "binding_count": "1"}]
|
||||
@ -109,17 +113,13 @@ class TestTagListApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.save_tags",
|
||||
return_value=tag,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, admin_user)
|
||||
|
||||
assert status == 200
|
||||
assert result["name"] == "test-tag"
|
||||
@ -133,14 +133,10 @@ class TestTagListApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, readonly_user)
|
||||
|
||||
|
||||
class TestTagUpdateDeleteApi:
|
||||
@ -152,10 +148,6 @@ class TestTagUpdateDeleteApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch(
|
||||
"controllers.console.tag.tags.TagService.update_tags",
|
||||
@ -166,7 +158,7 @@ class TestTagUpdateDeleteApi:
|
||||
return_value=3,
|
||||
),
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
result, status = method(api, admin_user, "tag-1")
|
||||
|
||||
assert status == 200
|
||||
update_payload, tag_id = update_tags_mock.call_args.args
|
||||
@ -182,14 +174,10 @@ class TestTagUpdateDeleteApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api, "tag-1")
|
||||
method(api, readonly_user, "tag-1")
|
||||
|
||||
def test_delete_success(self, app: Flask, admin_user):
|
||||
api = TagUpdateDeleteApi()
|
||||
@ -197,10 +185,6 @@ class TestTagUpdateDeleteApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, "tenant-1"),
|
||||
),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock,
|
||||
):
|
||||
result, status = method(api, "tag-1")
|
||||
@ -222,14 +206,10 @@ class TestTagBindingCollectionApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, admin_user)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert status == 200
|
||||
@ -241,14 +221,10 @@ class TestTagBindingCollectionApi:
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, readonly_user)
|
||||
|
||||
|
||||
class TestTagBindingRemoveApi:
|
||||
@ -264,14 +240,10 @@ class TestTagBindingRemoveApi:
|
||||
|
||||
with app.test_request_context("/", json=payload):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(admin_user, None),
|
||||
),
|
||||
payload_patch(payload),
|
||||
patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
result, status = method(api, admin_user)
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
delete_payload = delete_mock.call_args.args[0]
|
||||
@ -285,14 +257,10 @@ class TestTagBindingRemoveApi:
|
||||
|
||||
with app.test_request_context("/", json={}):
|
||||
with (
|
||||
patch(
|
||||
"controllers.console.tag.tags.current_account_with_tenant",
|
||||
return_value=(readonly_user, None),
|
||||
),
|
||||
payload_patch({}),
|
||||
):
|
||||
with pytest.raises(Forbidden):
|
||||
method(api)
|
||||
method(api, readonly_user)
|
||||
|
||||
|
||||
class TestTagResponseModel:
|
||||
|
||||
@ -19,6 +19,8 @@ from controllers.console.files import (
|
||||
FilePreviewApi,
|
||||
FileSupportTypeApi,
|
||||
)
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
|
||||
|
||||
def unwrap(func):
|
||||
@ -53,14 +55,15 @@ def mock_decorators():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user():
|
||||
user = MagicMock()
|
||||
user.is_dataset_editor = True
|
||||
user = Account(name="Test User", email="user-1@example.com", status=AccountStatus.ACTIVE)
|
||||
user.id = "user-1"
|
||||
user.role = TenantAccountRole.OWNER
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_tenant_id():
|
||||
return "tenant-123"
|
||||
def mock_account_context(mock_current_user):
|
||||
return mock_current_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -91,15 +94,15 @@ class TestFileApiGet:
|
||||
|
||||
|
||||
class TestFileApiPost:
|
||||
def test_no_file_uploaded(self, app: Flask, mock_current_user):
|
||||
def test_no_file_uploaded(self, app: Flask, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(method="POST", data={}):
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_too_many_files(self, app: Flask, mock_current_user):
|
||||
def test_too_many_files(self, app: Flask, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -114,9 +117,9 @@ class TestFileApiPost:
|
||||
mock_request.form.get.return_value = None
|
||||
|
||||
with pytest.raises(TooManyFilesError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_filename_missing(self, app: Flask, mock_current_user):
|
||||
def test_filename_missing(self, app: Flask, mock_account_context):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -126,10 +129,10 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FilenameNotExistsError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_dataset_upload_without_permission(self, app: Flask, mock_current_user):
|
||||
mock_current_user.is_dataset_editor = False
|
||||
mock_current_user.role = TenantAccountRole.NORMAL
|
||||
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
@ -143,7 +146,7 @@ class TestFileApiPost:
|
||||
with pytest.raises(Forbidden):
|
||||
post_method(api, mock_current_user)
|
||||
|
||||
def test_successful_upload(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_successful_upload(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -171,13 +174,13 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api, mock_current_user)
|
||||
response, status = post_method(api, mock_account_context)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-123"
|
||||
assert response["name"] == "test.txt"
|
||||
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_upload_with_invalid_source(self, app: Flask, mock_account_context, mock_file_service):
|
||||
"""Test that invalid source parameter gets normalized to None"""
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
@ -208,7 +211,7 @@ class TestFileApiPost:
|
||||
}
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
response, status = post_method(api, mock_current_user)
|
||||
response, status = post_method(api, mock_account_context)
|
||||
|
||||
assert status == 201
|
||||
assert response["id"] == "file-id-456"
|
||||
@ -217,7 +220,7 @@ class TestFileApiPost:
|
||||
call_kwargs = mock_file_service.upload_file.call_args[1]
|
||||
assert call_kwargs["source"] is None
|
||||
|
||||
def test_file_too_large_error(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_file_too_large_error(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -232,9 +235,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(FileTooLargeError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_unsupported_file_type(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_unsupported_file_type(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -249,9 +252,9 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
def test_blocked_extension(self, app: Flask, mock_current_user, mock_file_service):
|
||||
def test_blocked_extension(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FileApi()
|
||||
post_method = unwrap(api.post)
|
||||
|
||||
@ -266,17 +269,17 @@ class TestFileApiPost:
|
||||
|
||||
with app.test_request_context(method="POST", data=data):
|
||||
with pytest.raises(BlockedFileExtensionError):
|
||||
post_method(api, mock_current_user)
|
||||
post_method(api, mock_account_context)
|
||||
|
||||
|
||||
class TestFilePreviewApi:
|
||||
def test_get_preview(self, app: Flask, mock_current_tenant_id, mock_file_service):
|
||||
def test_get_preview(self, app: Flask, mock_account_context, mock_file_service):
|
||||
api = FilePreviewApi()
|
||||
get_method = unwrap(api.get)
|
||||
mock_file_service.get_file_preview.return_value = "preview text"
|
||||
|
||||
with app.test_request_context():
|
||||
result = get_method(api, mock_current_tenant_id, "1234")
|
||||
result = get_method(api, "tenant-123", "1234")
|
||||
|
||||
assert result == {"content": "preview text"}
|
||||
|
||||
|
||||
@ -10,6 +10,8 @@ import pytest
|
||||
|
||||
from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError
|
||||
from controllers.console import remote_files as remote_files_module
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
@ -20,6 +22,17 @@ def _unwrap(func):
|
||||
return func
|
||||
|
||||
|
||||
def _make_account(account_id: str = "u1") -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
email=f"{account_id}@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = account_id
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(
|
||||
self,
|
||||
@ -48,7 +61,6 @@ def _mock_upload_dependencies(
|
||||
*,
|
||||
file_size_within_limit: bool = True,
|
||||
):
|
||||
current_user = SimpleNamespace(id="u1")
|
||||
file_info = SimpleNamespace(
|
||||
filename="report.txt",
|
||||
extension=".txt",
|
||||
@ -64,6 +76,7 @@ def _mock_upload_dependencies(
|
||||
file_service_cls = MagicMock()
|
||||
file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit
|
||||
monkeypatch.setattr(remote_files_module, "FileService", file_service_cls)
|
||||
current_user = _make_account()
|
||||
monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(
|
||||
remote_files_module.file_helpers,
|
||||
@ -226,7 +239,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"):
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
handler(api, _make_account())
|
||||
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@ -243,7 +256,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
|
||||
with app.test_request_context(method="POST", json={"url": url}):
|
||||
with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"):
|
||||
handler(api, SimpleNamespace(id="u1"))
|
||||
handler(api, _make_account())
|
||||
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager, UserMixin
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
@ -17,8 +18,11 @@ from controllers.console.wraps import (
|
||||
only_edition_enterprise,
|
||||
only_edition_self_hosted,
|
||||
setup_required,
|
||||
with_current_tenant_id,
|
||||
with_current_user,
|
||||
)
|
||||
from models.account import AccountStatus
|
||||
from models import Account
|
||||
from models.account import AccountStatus, TenantAccountRole
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
|
||||
@ -33,6 +37,17 @@ class MockUser(UserMixin):
|
||||
return self.id
|
||||
|
||||
|
||||
def make_account(account_id: str = "account-1") -> Account:
|
||||
account = Account(
|
||||
name="Test Account",
|
||||
email=f"{account_id}@example.com",
|
||||
status=AccountStatus.ACTIVE,
|
||||
)
|
||||
account.id = account_id
|
||||
account.role = TenantAccountRole.OWNER
|
||||
return account
|
||||
|
||||
|
||||
def create_app_with_login():
|
||||
"""Create a Flask app with LoginManager configured."""
|
||||
app = Flask(__name__)
|
||||
@ -84,6 +99,42 @@ class TestAccountInitialization:
|
||||
protected_view()
|
||||
|
||||
|
||||
class TestCurrentContextInjection:
|
||||
"""Test request context injection decorators."""
|
||||
|
||||
def test_with_current_tenant_id_injects_tenant_id(self):
|
||||
class Handler:
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
return current_tenant_id
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(), "tenant-123")):
|
||||
assert Handler().get() == "tenant-123"
|
||||
|
||||
def test_with_current_user_injects_account(self):
|
||||
current_user = make_account()
|
||||
|
||||
class Handler:
|
||||
@with_current_user
|
||||
def get(self, injected_user):
|
||||
return injected_user
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
|
||||
assert Handler().get() is current_user
|
||||
|
||||
def test_stacked_current_context_injectors_preserve_argument_order(self):
|
||||
current_user = make_account()
|
||||
|
||||
class Handler:
|
||||
@with_current_user
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str, injected_user):
|
||||
return current_tenant_id, injected_user
|
||||
|
||||
with patch("controllers.console.wraps.current_account_with_tenant", return_value=(current_user, "tenant-123")):
|
||||
assert Handler().get() == ("tenant-123", current_user)
|
||||
|
||||
|
||||
class TestEditionChecks:
|
||||
"""Test edition-specific decorators"""
|
||||
|
||||
@ -114,7 +165,7 @@ class TestEditionChecks:
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
cloud_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
@ -177,7 +228,7 @@ class TestBillingEnabled:
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
billing_view()
|
||||
|
||||
assert exc_info.value.code == 403
|
||||
@ -230,7 +281,7 @@ class TestBillingResourceLimits:
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
add_member()
|
||||
assert exc_info.value.code == 403
|
||||
assert "members has reached the limit" in str(exc_info.value.description)
|
||||
@ -255,7 +306,7 @@ class TestBillingResourceLimits:
|
||||
return_value=(MockUser("test_user"), "tenant123"),
|
||||
):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
upload_document()
|
||||
assert exc_info.value.code == 403
|
||||
|
||||
@ -329,7 +380,7 @@ class TestRateLimiting:
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
knowledge_request()
|
||||
|
||||
# Verify error
|
||||
|
||||
Loading…
Reference in New Issue
Block a user