mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
chore: inject current user in console handlers (#36628)
This commit is contained in:
parent
135e01930b
commit
87268f0662
@ -9,9 +9,10 @@ from controllers.common.schema import register_response_schema_models, register_
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SavedMessageInfiniteScrollPagination, SavedMessageItem
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
@ -22,8 +23,8 @@ register_response_schema_models(console_ns, ResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/saved-messages", endpoint="installed_app_saved_messages")
|
||||
class SavedMessageListApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[SavedMessageListQuery.__name__])
|
||||
def get(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def get(self, current_user: Account, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -46,8 +47,8 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
|
||||
@console_ns.expect(console_ns.models[SavedMessageCreatePayload.__name__])
|
||||
@console_ns.response(200, "Success", console_ns.models[ResultResponse.__name__])
|
||||
def post(self, installed_app):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app):
|
||||
app_model = installed_app.app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
@ -67,8 +68,8 @@ class SavedMessageListApi(InstalledAppResource):
|
||||
)
|
||||
class SavedMessageApi(InstalledAppResource):
|
||||
@console_ns.response(204, "Saved message deleted successfully")
|
||||
def delete(self, installed_app, message_id: UUID):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@with_current_user
|
||||
def delete(self, current_user: Account, installed_app, message_id: UUID):
|
||||
app_model = installed_app.app
|
||||
|
||||
message_id_str = str(message_id)
|
||||
|
||||
@ -13,6 +13,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotWorkflowAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -25,7 +26,7 @@ from extensions.ext_redis import redis_client
|
||||
from graphon.graph_engine.manager import GraphEngineManager
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.login import current_account_with_tenant
|
||||
from models import Account
|
||||
from models.model import AppMode, InstalledApp
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -41,11 +42,11 @@ register_response_schema_models(console_ns, SimpleResultResponse)
|
||||
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
|
||||
class InstalledAppWorkflowRunApi(InstalledAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunPayload.__name__])
|
||||
def post(self, installed_app: InstalledApp):
|
||||
@with_current_user
|
||||
def post(self, current_user: Account, installed_app: InstalledApp):
|
||||
"""
|
||||
Run workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
app_model = installed_app.app
|
||||
if not app_model:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
@ -8,8 +8,14 @@ from pydantic import BaseModel, Field
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
@ -70,11 +76,10 @@ class NotificationApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
def get(self, current_user: Account):
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
@ -113,11 +118,11 @@ class NotificationDismissApi(Resource):
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@with_current_user
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account):
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
|
||||
@ -500,3 +500,14 @@ def with_current_tenant_id[T, **P, R](
|
||||
return view(self, current_tenant_id, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
def with_current_user[T, **P, R](
|
||||
view: Callable[Concatenate[T, Account, P], R],
|
||||
) -> Callable[Concatenate[T, P], R]:
|
||||
@wraps(view)
|
||||
def decorated(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
current_user, _ = current_account_with_tenant()
|
||||
return view(self, current_user, *args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
@ -55,18 +55,20 @@ class TestSavedMessageListApi:
|
||||
has_more=False,
|
||||
data=[make_saved_message(), make_saved_message()],
|
||||
)
|
||||
current_user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", query_string={}),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.SavedMessageService,
|
||||
"pagination_by_last_id",
|
||||
return_value=pagination,
|
||||
),
|
||||
) as pagination_mock,
|
||||
):
|
||||
result = method(installed_app)
|
||||
result = method(api, current_user, installed_app)
|
||||
|
||||
pagination_mock.assert_called_once()
|
||||
assert pagination_mock.call_args.args[1] is current_user
|
||||
assert result["limit"] == 20
|
||||
assert result["has_more"] is False
|
||||
assert len(result["data"]) == 2
|
||||
@ -78,9 +80,8 @@ class TestSavedMessageListApi:
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app)
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(api, MagicMock(), installed_app)
|
||||
|
||||
def test_post_success(self, app: Flask, payload_patch):
|
||||
api = module.SavedMessageListApi()
|
||||
@ -90,16 +91,17 @@ class TestSavedMessageListApi:
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
|
||||
payload = {"message_id": str(uuid4())}
|
||||
current_user = MagicMock()
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
payload_patch(payload),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(module.SavedMessageService, "save") as save_mock,
|
||||
):
|
||||
result = method(installed_app)
|
||||
result = method(api, current_user, installed_app)
|
||||
|
||||
save_mock.assert_called_once()
|
||||
assert save_mock.call_args.args[1] is current_user
|
||||
assert result == {"result": "success"}
|
||||
|
||||
def test_post_message_not_exists(self, app: Flask, payload_patch):
|
||||
@ -114,7 +116,6 @@ class TestSavedMessageListApi:
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
payload_patch(payload),
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(
|
||||
module.SavedMessageService,
|
||||
"save",
|
||||
@ -122,7 +123,7 @@ class TestSavedMessageListApi:
|
||||
),
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(installed_app)
|
||||
method(api, MagicMock(), installed_app)
|
||||
|
||||
|
||||
class TestSavedMessageApi:
|
||||
@ -132,14 +133,15 @@ class TestSavedMessageApi:
|
||||
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="completion")
|
||||
current_user = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)),
|
||||
patch.object(module.SavedMessageService, "delete") as delete_mock,
|
||||
):
|
||||
result, status = method(installed_app, str(uuid4()))
|
||||
result, status = method(api, current_user, installed_app, str(uuid4()))
|
||||
|
||||
delete_mock.assert_called_once()
|
||||
assert delete_mock.call_args.args[1] is current_user
|
||||
assert status == 204
|
||||
assert result == ""
|
||||
|
||||
@ -150,6 +152,5 @@ class TestSavedMessageApi:
|
||||
installed_app = MagicMock()
|
||||
installed_app.app = MagicMock(mode="chat")
|
||||
|
||||
with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)):
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, str(uuid4()))
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(api, MagicMock(), installed_app, str(uuid4()))
|
||||
|
||||
@ -61,15 +61,9 @@ class TestInstalledAppWorkflowRunApi:
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
with (
|
||||
app.test_request_context("/"),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(MagicMock(), None),
|
||||
),
|
||||
):
|
||||
with app.test_request_context("/"):
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
method(non_workflow_installed_app)
|
||||
method(api, MagicMock(), non_workflow_installed_app)
|
||||
|
||||
def test_success(self, app: Flask, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
@ -77,18 +71,15 @@ class TestInstalledAppWorkflowRunApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
return_value=MagicMock(),
|
||||
) as generate_mock,
|
||||
):
|
||||
result = method(installed_workflow_app)
|
||||
result = method(api, user, installed_workflow_app)
|
||||
|
||||
generate_mock.assert_called_once()
|
||||
assert generate_mock.call_args.kwargs["user"] is user
|
||||
assert result is not None
|
||||
|
||||
def test_rate_limit_error(self, app: Flask, installed_workflow_app, user, payload):
|
||||
@ -97,17 +88,13 @@ class TestInstalledAppWorkflowRunApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
side_effect=InvokeRateLimitError("rate limit"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InvokeRateLimitHttpError):
|
||||
method(installed_workflow_app)
|
||||
method(api, user, installed_workflow_app)
|
||||
|
||||
def test_unexpected_exception(self, app: Flask, installed_workflow_app, user, payload):
|
||||
api = InstalledAppWorkflowRunApi()
|
||||
@ -115,17 +102,13 @@ class TestInstalledAppWorkflowRunApi:
|
||||
|
||||
with (
|
||||
app.test_request_context("/", json=payload),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.current_account_with_tenant",
|
||||
return_value=(user, None),
|
||||
),
|
||||
patch(
|
||||
"controllers.console.explore.workflow.AppGenerateService.generate",
|
||||
side_effect=Exception("boom"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(InternalServerError):
|
||||
method(installed_workflow_app)
|
||||
method(api, user, installed_workflow_app)
|
||||
|
||||
|
||||
class TestInstalledAppWorkflowTaskStopApi:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user