diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 224715d255..fc863b78d7 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -9,9 +9,10 @@ from controllers.common.schema import register_response_schema_models, register_ from controllers.console import console_ns from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource +from controllers.console.wraps import with_current_user from fields.conversation_fields import ResultResponse from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem -from libs.login import current_account_with_tenant +from models import Account from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService @@ -22,8 +23,8 @@ register_response_schema_models(console_ns, ResultResponse) @console_ns.route("/installed-apps//saved-messages", endpoint="installed_app_saved_messages") class SavedMessageListApi(InstalledAppResource): @console_ns.expect(console_ns.models[SavedMessageListQuery.__name__]) - def get(self, installed_app): - current_user, _ = current_account_with_tenant() + @with_current_user + def get(self, current_user: Account, installed_app): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() @@ -46,8 +47,8 @@ class SavedMessageListApi(InstalledAppResource): @console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__]) @console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__]) - def post(self, installed_app): - current_user, _ = current_account_with_tenant() + @with_current_user + def post(self, current_user: Account, installed_app): app_model = installed_app.app if app_model.mode != "completion": raise NotCompletionAppError() @@ -67,8 +68,8 @@ class SavedMessageListApi(InstalledAppResource): ) class SavedMessageApi(InstalledAppResource): @console_ns.response(204, "Saved message deleted successfully") - def delete(self, installed_app, message_id: UUID): - current_user, _ = current_account_with_tenant() + @with_current_user + def delete(self, current_user: Account, installed_app, message_id: UUID): app_model = installed_app.app message_id_str = str(message_id) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index bed8425c35..ebd13e586b 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -13,6 +13,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.error import NotWorkflowAppError from controllers.console.explore.wraps import InstalledAppResource +from controllers.console.wraps import with_current_user from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom @@ -25,7 +26,7 @@ from extensions.ext_redis import redis_client from graphon.graph_engine.manager import GraphEngineManager from graphon.model_runtime.errors.invoke import InvokeError from libs import helper -from libs.login import current_account_with_tenant +from models import Account from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -41,11 +42,11 @@ register_response_schema_models(console_ns, SimpleResultResponse) @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): @console_ns.expect(console_ns.models[WorkflowRunPayload.__name__]) - def post(self, installed_app: InstalledApp): + @with_current_user + def post(self, current_user: Account, installed_app: InstalledApp): """ Run workflow """ - current_user, _ = current_account_with_tenant() app_model = installed_app.app if not app_model: raise NotWorkflowAppError() diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py index f54a6137b3..ee59f3d564 100644 --- a/api/controllers/console/notification.py +++ b/api/controllers/console/notification.py @@ -8,8 +8,14 @@ from pydantic import BaseModel, Field from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models from controllers.console import console_ns -from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required -from libs.login import current_account_with_tenant, login_required +from controllers.console.wraps import ( + account_initialization_required, + only_edition_cloud, + setup_required, + with_current_user, +) +from libs.login import login_required +from models import Account from services.billing_service import BillingService # Notification content is stored under three lang tags. @@ -70,11 +76,10 @@ class NotificationApi(Resource): ) @setup_required @login_required + @with_current_user @account_initialization_required @only_edition_cloud - def get(self): - current_user, _ = current_account_with_tenant() - + def get(self, current_user: Account): result = BillingService.get_account_notification(str(current_user.id)) # Proto JSON uses camelCase field names (Kratos default marshaling). @@ -113,11 +118,11 @@ class NotificationDismissApi(Resource): ) @setup_required @login_required + @with_current_user @account_initialization_required @only_edition_cloud @console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__]) - def post(self): - current_user, _ = current_account_with_tenant() + def post(self, current_user: Account): payload = DismissNotificationPayload.model_validate(request.get_json()) BillingService.dismiss_notification( notification_id=payload.notification_id, diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index f31aa33f16..ad406f2a9e 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -500,3 +500,14 @@ def with_current_tenant_id[T, **P, R]( return view(self, current_tenant_id, *args, **kwargs) return decorated + + +def with_current_user[T, **P, R]( + view: Callable[Concatenate[T, Account, P], R], +) -> Callable[Concatenate[T, P], R]: + @wraps(view) + def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R: + current_user, _ = current_account_with_tenant() + return view(self, current_user, *args, **kwargs) + + return decorated diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py index 00c0d91d1d..07e674afad 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -55,18 +55,20 @@ class TestSavedMessageListApi: has_more=False, data=[make_saved_message(), make_saved_message()], ) + current_user = MagicMock() with ( app.test_request_context("/", query_string={}), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.SavedMessageService, "pagination_by_last_id", return_value=pagination, - ), + ) as pagination_mock, ): - result = method(installed_app) + result = method(api, current_user, installed_app) + pagination_mock.assert_called_once() + assert pagination_mock.call_args.args[1] is current_user assert result["limit"] == 20 assert result["has_more"] is False assert len(result["data"]) == 2 @@ -78,9 +80,8 @@ class TestSavedMessageListApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="chat") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotCompletionAppError): - method(installed_app) + with pytest.raises(NotCompletionAppError): + method(api, MagicMock(), installed_app) def test_post_success(self, app: Flask, payload_patch): api = module.SavedMessageListApi() @@ -90,16 +91,17 @@ class TestSavedMessageListApi: installed_app.app = MagicMock(mode="completion") payload = {"message_id": str(uuid4())} + current_user = MagicMock() with ( app.test_request_context("/", json=payload), payload_patch(payload), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object(module.SavedMessageService, "save") as save_mock, ): - result = method(installed_app) + result = method(api, current_user, installed_app) save_mock.assert_called_once() + assert save_mock.call_args.args[1] is current_user assert result == {"result": "success"} def test_post_message_not_exists(self, app: Flask, payload_patch): @@ -114,7 +116,6 @@ class TestSavedMessageListApi: with ( app.test_request_context("/", json=payload), payload_patch(payload), - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object( module.SavedMessageService, "save", @@ -122,7 +123,7 @@ class TestSavedMessageListApi: ), ): with pytest.raises(NotFound): - method(installed_app) + method(api, MagicMock(), installed_app) class TestSavedMessageApi: @@ -132,14 +133,15 @@ class TestSavedMessageApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="completion") + current_user = MagicMock() with ( - patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), patch.object(module.SavedMessageService, "delete") as delete_mock, ): - result, status = method(installed_app, str(uuid4())) + result, status = method(api, current_user, installed_app, str(uuid4())) delete_mock.assert_called_once() + assert delete_mock.call_args.args[1] is current_user assert status == 204 assert result == "" @@ -150,6 +152,5 @@ class TestSavedMessageApi: installed_app = MagicMock() installed_app.app = MagicMock(mode="chat") - with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): - with pytest.raises(NotCompletionAppError): - method(installed_app, str(uuid4())) + with pytest.raises(NotCompletionAppError): + method(api, MagicMock(), installed_app, str(uuid4())) diff --git a/api/tests/unit_tests/controllers/console/explore/test_workflow.py b/api/tests/unit_tests/controllers/console/explore/test_workflow.py index 3a01925204..c5b2f0bd9b 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/explore/test_workflow.py @@ -61,15 +61,9 @@ class TestInstalledAppWorkflowRunApi: api = InstalledAppWorkflowRunApi() method = unwrap(api.post) - with ( - app.test_request_context("/"), - patch( - "controllers.console.explore.workflow.current_account_with_tenant", - return_value=(MagicMock(), None), - ), - ): + with app.test_request_context("/"): with pytest.raises(NotWorkflowAppError): - method(non_workflow_installed_app) + method(api, MagicMock(), non_workflow_installed_app) def test_success(self, app: Flask, installed_workflow_app, user, payload): api = InstalledAppWorkflowRunApi() @@ -77,18 +71,15 @@ class TestInstalledAppWorkflowRunApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.explore.workflow.current_account_with_tenant", - return_value=(user, None), - ), patch( "controllers.console.explore.workflow.AppGenerateService.generate", return_value=MagicMock(), ) as generate_mock, ): - result = method(installed_workflow_app) + result = method(api, user, installed_workflow_app) generate_mock.assert_called_once() + assert generate_mock.call_args.kwargs["user"] is user assert result is not None def test_rate_limit_error(self, app: Flask, installed_workflow_app, user, payload): @@ -97,17 +88,13 @@ class TestInstalledAppWorkflowRunApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.explore.workflow.current_account_with_tenant", - return_value=(user, None), - ), patch( "controllers.console.explore.workflow.AppGenerateService.generate", side_effect=InvokeRateLimitError("rate limit"), ), ): with pytest.raises(InvokeRateLimitHttpError): - method(installed_workflow_app) + method(api, user, installed_workflow_app) def test_unexpected_exception(self, app: Flask, installed_workflow_app, user, payload): api = InstalledAppWorkflowRunApi() @@ -115,17 +102,13 @@ class TestInstalledAppWorkflowRunApi: with ( app.test_request_context("/", json=payload), - patch( - "controllers.console.explore.workflow.current_account_with_tenant", - return_value=(user, None), - ), patch( "controllers.console.explore.workflow.AppGenerateService.generate", side_effect=Exception("boom"), ), ): with pytest.raises(InternalServerError): - method(installed_workflow_app) + method(api, user, installed_workflow_app) class TestInstalledAppWorkflowTaskStopApi: