mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:23:44 +08:00
refactor: inject current user id in stop message endpoints (#36925)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
6ce61eae59
commit
7056985f72
@ -19,7 +19,12 @@ from controllers.console.app.error import (
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user_id,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -146,14 +151,13 @@ class CompletionMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=AppMode.COMPLETION)
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -234,14 +238,13 @@ class ChatMessageStopApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def post(self, app_model: App, task_id: str):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, app_model: App, task_id: str):
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ from controllers.console.app.error import (
|
||||
)
|
||||
from controllers.console.explore.error import NotChatAppError, NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import with_current_user_id
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
@ -135,20 +136,18 @@ class CompletionApi(InstalledAppResource):
|
||||
)
|
||||
class CompletionStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
if app_model.mode != AppMode.COMPLETION:
|
||||
raise NotCompletionAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=AppMode.value_of(app_model.mode),
|
||||
)
|
||||
|
||||
@ -215,7 +214,8 @@ class ChatApi(InstalledAppResource):
|
||||
)
|
||||
class ChatStopApi(InstalledAppResource):
|
||||
@console_ns.response(200, "Success", console_ns.models[SimpleResultResponse.__name__])
|
||||
def post(self, installed_app: InstalledApp, task_id: str):
|
||||
@with_current_user_id
|
||||
def post(self, current_user_id: str, installed_app: InstalledApp, task_id: str):
|
||||
app_model = installed_app.app
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
@ -223,13 +223,10 @@ class ChatStopApi(InstalledAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
AppTaskService.stop_task(
|
||||
task_id=task_id,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
user_id=current_user.id,
|
||||
user_id=current_user_id,
|
||||
app_mode=app_mode,
|
||||
)
|
||||
|
||||
|
||||
@ -222,17 +222,12 @@ class TestCompletionApi:
|
||||
|
||||
|
||||
class TestCompletionStopApi:
|
||||
def test_stop_success(self, completion_app, user):
|
||||
def test_stop_success(self, completion_app):
|
||||
api = completion_module.CompletionStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user.id = "u1"
|
||||
|
||||
with (
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(completion_module.AppTaskService, "stop_task"),
|
||||
):
|
||||
resp, status = method(completion_app, "task-1")
|
||||
with patch.object(completion_module.AppTaskService, "stop_task"):
|
||||
resp, status = method(api, "u1", completion_app, "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp == {"result": "success"}
|
||||
@ -244,7 +239,7 @@ class TestCompletionStopApi:
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT))
|
||||
|
||||
with pytest.raises(NotCompletionAppError):
|
||||
method(installed_app, "task")
|
||||
method(api, "u1", installed_app, "task")
|
||||
|
||||
|
||||
class TestChatApi:
|
||||
@ -435,17 +430,11 @@ class TestChatApi:
|
||||
|
||||
|
||||
class TestChatStopApi:
|
||||
def test_stop_success(self, chat_app, user):
|
||||
def test_stop_success(self, chat_app):
|
||||
api = completion_module.ChatStopApi()
|
||||
method = unwrap(api.post)
|
||||
|
||||
user.id = "u1"
|
||||
|
||||
with (
|
||||
patch.object(completion_module, "current_user", user),
|
||||
patch.object(completion_module.AppTaskService, "stop_task"),
|
||||
):
|
||||
resp, status = method(chat_app, "task-1")
|
||||
with patch.object(completion_module.AppTaskService, "stop_task"):
|
||||
resp, status = method(api, "u1", chat_app, "task-1")
|
||||
|
||||
assert status == 200
|
||||
assert resp == {"result": "success"}
|
||||
@ -457,4 +446,4 @@ class TestChatStopApi:
|
||||
installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION))
|
||||
|
||||
with pytest.raises(NotChatAppError):
|
||||
method(installed_app, "task")
|
||||
method(api, "u1", installed_app, "task")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user