chore: inject current user in console handlers (#36628)

This commit is contained in:
Tianle 2026-05-25 08:14:08 -05:00 committed by GitHub
parent 135e01930b
commit 87268f0662
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 55 deletions

View File

@ -9,9 +9,10 @@ from controllers.common.schema import register_response_schema_models, register_
from controllers.console import console_ns
from controllers.console.explore.error import NotCompletionAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from fields.conversation_fields import ResultResponse
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
from libs.login import current_account_with_tenant
from models import Account
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
@ -22,8 +23,8 @@ register_response_schema_models(console_ns, ResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
def get(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def get(self, current_user: Account, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -46,8 +47,8 @@ class SavedMessageListApi(InstalledAppResource):
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
def post(self, installed_app):
current_user, _ = current_account_with_tenant()
@with_current_user
def post(self, current_user: Account, installed_app):
app_model = installed_app.app
if app_model.mode != "completion":
raise NotCompletionAppError()
@ -67,8 +68,8 @@ class SavedMessageListApi(InstalledAppResource):
)
class SavedMessageApi(InstalledAppResource):
@console_ns.response(204, "Saved message deleted successfully")
def delete(self, installed_app, message_id: UUID):
current_user, _ = current_account_with_tenant()
@with_current_user
def delete(self, current_user: Account, installed_app, message_id: UUID):
app_model = installed_app.app
message_id_str = str(message_id)

View File

@ -13,6 +13,7 @@ from controllers.console.app.error import (
)
from controllers.console.explore.error import NotWorkflowAppError
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import with_current_user
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
@ -25,7 +26,7 @@ from extensions.ext_redis import redis_client
from graphon.graph_engine.manager import GraphEngineManager
from graphon.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.login import current_account_with_tenant
from models import Account
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -41,11 +42,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
def post(self, installed_app: InstalledApp):
@with_current_user
def post(self, current_user: Account, installed_app: InstalledApp):
"""
Run workflow
"""
current_user, _ = current_account_with_tenant()
app_model = installed_app.app
if not app_model:
raise NotWorkflowAppError()

View File

@ -8,8 +8,14 @@ from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from controllers.console.wraps import (
account_initialization_required,
only_edition_cloud,
setup_required,
with_current_user,
)
from libs.login import login_required
from models import Account
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
@ -70,11 +76,10 @@ class NotificationApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, _ = current_account_with_tenant()
def get(self, current_user: Account):
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
@ -113,11 +118,11 @@ class NotificationDismissApi(Resource):
)
@setup_required
@login_required
@with_current_user
@account_initialization_required
@only_edition_cloud
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
def post(self):
current_user, _ = current_account_with_tenant()
def post(self, current_user: Account):
payload = DismissNotificationPayload.model_validate(request.get_json())
BillingService.dismiss_notification(
notification_id=payload.notification_id,

View File

@ -500,3 +500,14 @@ def with_current_tenant_id[T, **P, R](
return view(self, current_tenant_id, *args, **kwargs)
return decorated
def with_current_user[T, **P, R](
view: Callable[Concatenate[T, Account, P], R],
) -> Callable[Concatenate[T, P], R]:
@wraps(view)
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
current_user, _ = current_account_with_tenant()
return view(self, current_user, *args, **kwargs)
return decorated

View File

@ -55,18 +55,20 @@ class TestSavedMessageListApi:
has_more=False,
data=[make_saved_message(), make_saved_message()],
)
current_user = MagicMock()
with (
app.test_request_context("/", query_string={}),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.SavedMessageService,
"pagination_by_last_id",
return_value=pagination,
),
) as pagination_mock,
):
result = method(installed_app)
result = method(api, current_user, installed_app)
pagination_mock.assert_called_once()
assert pagination_mock.call_args.args[1] is current_user
assert result["limit"] == 20
assert result["has_more"] is False
assert len(result["data"]) == 2
@ -78,9 +80,8 @@ class TestSavedMessageListApi:
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app)
with pytest.raises(NotCompletionAppError):
method(api, MagicMock(), installed_app)
def test_post_success(self, app: Flask, payload_patch):
api = module.SavedMessageListApi()
@ -90,16 +91,17 @@ class TestSavedMessageListApi:
installed_app.app = MagicMock(mode="completion")
payload = {"message_id": str(uuid4())}
current_user = MagicMock()
with (
app.test_request_context("/", json=payload),
payload_patch(payload),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(module.SavedMessageService, "save") as save_mock,
):
result = method(installed_app)
result = method(api, current_user, installed_app)
save_mock.assert_called_once()
assert save_mock.call_args.args[1] is current_user
assert result == {"result": "success"}
def test_post_message_not_exists(self, app: Flask, payload_patch):
@ -114,7 +116,6 @@ class TestSavedMessageListApi:
with (
app.test_request_context("/", json=payload),
payload_patch(payload),
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(
module.SavedMessageService,
"save",
@ -122,7 +123,7 @@ class TestSavedMessageListApi:
),
):
with pytest.raises(NotFound):
method(installed_app)
method(api, MagicMock(), installed_app)
class TestSavedMessageApi:
@ -132,14 +133,15 @@ class TestSavedMessageApi:
installed_app = MagicMock()
installed_app.app = MagicMock(mode="completion")
current_user = MagicMock()
with (
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
patch.object(module.SavedMessageService, "delete") as delete_mock,
):
result, status = method(installed_app, str(uuid4()))
result, status = method(api, current_user, installed_app, str(uuid4()))
delete_mock.assert_called_once()
assert delete_mock.call_args.args[1] is current_user
assert status == 204
assert result == ""
@ -150,6 +152,5 @@ class TestSavedMessageApi:
installed_app = MagicMock()
installed_app.app = MagicMock(mode="chat")
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
with pytest.raises(NotCompletionAppError):
method(installed_app, str(uuid4()))
with pytest.raises(NotCompletionAppError):
method(api, MagicMock(), installed_app, str(uuid4()))

View File

@ -61,15 +61,9 @@ class TestInstalledAppWorkflowRunApi:
api = InstalledAppWorkflowRunApi()
method = unwrap(api.post)
with (
app.test_request_context("/"),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(MagicMock(), None),
),
):
with app.test_request_context("/"):
with pytest.raises(NotWorkflowAppError):
method(non_workflow_installed_app)
method(api, MagicMock(), non_workflow_installed_app)
def test_success(self, app: Flask, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
@ -77,18 +71,15 @@ class TestInstalledAppWorkflowRunApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
return_value=MagicMock(),
) as generate_mock,
):
result = method(installed_workflow_app)
result = method(api, user, installed_workflow_app)
generate_mock.assert_called_once()
assert generate_mock.call_args.kwargs["user"] is user
assert result is not None
def test_rate_limit_error(self, app: Flask, installed_workflow_app, user, payload):
@ -97,17 +88,13 @@ class TestInstalledAppWorkflowRunApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
side_effect=InvokeRateLimitError("rate limit"),
),
):
with pytest.raises(InvokeRateLimitHttpError):
method(installed_workflow_app)
method(api, user, installed_workflow_app)
def test_unexpected_exception(self, app: Flask, installed_workflow_app, user, payload):
api = InstalledAppWorkflowRunApi()
@ -115,17 +102,13 @@ class TestInstalledAppWorkflowRunApi:
with (
app.test_request_context("/", json=payload),
patch(
"controllers.console.explore.workflow.current_account_with_tenant",
return_value=(user, None),
),
patch(
"controllers.console.explore.workflow.AppGenerateService.generate",
side_effect=Exception("boom"),
),
):
with pytest.raises(InternalServerError):
method(installed_workflow_app)
method(api, user, installed_workflow_app)
class TestInstalledAppWorkflowTaskStopApi: