mirror of
https://github.com/langgenius/dify.git
synced 2026-06-10 18:24:09 +08:00
chore: DI current_user && use inspect (#37084)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
bbdf3d7634
commit
d11e4eeaf7
@ -21,7 +21,12 @@ from controllers.common.schema import (
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
edit_permission_required,
|
||||
setup_required,
|
||||
with_current_user,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -54,7 +59,7 @@ from libs import helper
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models import Account, App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
@ -401,13 +406,12 @@ class DraftWorkflowApi(Resource):
|
||||
)
|
||||
@console_ns.response(400, "Invalid workflow configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Sync draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
|
||||
if "application/json" in content_type:
|
||||
@ -468,13 +472,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
@ -514,12 +517,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
@ -552,12 +555,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow iteration node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
@ -590,12 +593,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -628,12 +631,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow loop node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
try:
|
||||
@ -695,12 +698,12 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
@ -724,12 +727,12 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service = WorkflowService()
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
@ -753,12 +756,12 @@ class WorkflowDraftHumanInputFormPreviewApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
@ -782,12 +785,12 @@ class WorkflowDraftHumanInputFormRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
@ -811,12 +814,12 @@ class WorkflowDraftHumanInputDeliveryTestApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Test human input delivery
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service.test_human_input_delivery(
|
||||
@ -841,12 +844,12 @@ class DraftWorkflowRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Run draft workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
|
||||
|
||||
external_trace_id = get_external_trace_id(request)
|
||||
@ -911,12 +914,12 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
|
||||
args = args_model.model_dump(exclude_none=True)
|
||||
|
||||
@ -981,12 +984,12 @@ class PublishedWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Publish workflow
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
@ -1083,14 +1086,14 @@ class ConvertToWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Convert basic mode of chatbot app to workflow mode
|
||||
Convert expert mode of chatbot app to workflow mode
|
||||
Convert Completion App to Workflow App
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
payload = console_ns.payload or {}
|
||||
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
|
||||
@ -1122,9 +1125,9 @@ class WorkflowFeaturesApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
@ -1150,12 +1153,12 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
def get(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Get published workflows
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))
|
||||
page = args.page
|
||||
@ -1199,9 +1202,9 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
def post(self, current_user: Account, app_model: App, workflow_id: str):
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
try:
|
||||
@ -1237,12 +1240,12 @@ class WorkflowByIdApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
def patch(self, current_user: Account, app_model: App, workflow_id: str):
|
||||
"""
|
||||
Update workflow attributes
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# Prepare update data
|
||||
@ -1355,12 +1358,12 @@ class DraftWorkflowTriggerRunApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Poll for trigger events and execute full workflow when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
|
||||
node_id = args.node_id
|
||||
workflow_service = WorkflowService()
|
||||
@ -1419,12 +1422,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
def post(self, current_user: Account, app_model: App, node_id: str):
|
||||
"""
|
||||
Poll for trigger events and execute single node when event arrives
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.get_draft_workflow(app_model)
|
||||
@ -1499,12 +1502,12 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@with_current_user
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
def post(self, current_user: Account, app_model: App):
|
||||
"""
|
||||
Full workflow debug when the start node is a trigger
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
|
||||
node_ids = args.node_ids
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@ -68,15 +69,6 @@ from tests.test_containers_integration_tests.controllers.console.helpers import
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="tester",
|
||||
@ -108,7 +100,7 @@ class TestCompletionEndpoints:
|
||||
|
||||
def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
@ -125,13 +117,13 @@ class TestCompletionEndpoints:
|
||||
"/",
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
resp = method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
resp = method(api, _make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert resp == {"result": {"text": "ok"}}
|
||||
|
||||
def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
@ -146,11 +138,11 @@ class TestCompletionEndpoints:
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
method(api, _make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
@ -163,11 +155,11 @@ class TestCompletionEndpoints:
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderNotInitializeError):
|
||||
method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
method(api, _make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = completion_module.CompletionMessageApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
completion_module.AppGenerateService,
|
||||
@ -180,7 +172,7 @@ class TestCompletionEndpoints:
|
||||
json={"inputs": {}, "model_config": {}, "query": "hi"},
|
||||
):
|
||||
with pytest.raises(completion_module.ProviderQuotaExceededError):
|
||||
method(_make_account(), app_model=MagicMock(id="app-1"))
|
||||
method(api, _make_account(), app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
class TestAppEndpoints:
|
||||
@ -190,7 +182,7 @@ class TestAppEndpoints:
|
||||
|
||||
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = app_module.AppApi()
|
||||
method = _unwrap(api.put)
|
||||
method = unwrap(api.put)
|
||||
payload = {
|
||||
"name": "Updated App",
|
||||
"description": "Updated description",
|
||||
@ -209,7 +201,7 @@ class TestAppEndpoints:
|
||||
app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
):
|
||||
response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI))
|
||||
response = method(api, app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI))
|
||||
|
||||
assert response == {"id": "app-1"}
|
||||
assert app_service.update_app.call_args.args[1]["icon_type"] is None
|
||||
@ -228,7 +220,7 @@ class TestAppEndpoints:
|
||||
|
||||
def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = app_module.AppIconApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
payload = {
|
||||
"icon": "https://example.com/icon.png",
|
||||
"icon_type": "image",
|
||||
@ -246,7 +238,7 @@ class TestAppEndpoints:
|
||||
app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload),
|
||||
patch.object(type(console_ns), "payload", payload),
|
||||
):
|
||||
response = method(app_model=SimpleNamespace())
|
||||
response = method(api, app_model=SimpleNamespace())
|
||||
|
||||
assert response == {"id": "app-1"}
|
||||
assert app_service.update_app_icon.call_args.args[1:] == (
|
||||
@ -300,7 +292,7 @@ class TestOpsTraceEndpoints:
|
||||
|
||||
def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
@ -309,13 +301,13 @@ class TestOpsTraceEndpoints:
|
||||
)
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
result = method(app_model=MagicMock(id="app-1"))
|
||||
result = method(api, app_model=MagicMock(id="app-1"))
|
||||
|
||||
assert result == {"has_not_configured": True}
|
||||
|
||||
def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
@ -328,11 +320,11 @@ class TestOpsTraceEndpoints:
|
||||
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
method(api, app_model=MagicMock(id="app-1"))
|
||||
|
||||
def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = ops_trace_module.TraceAppConfigApi()
|
||||
method = _unwrap(api.delete)
|
||||
method = unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(
|
||||
ops_trace_module.OpsService,
|
||||
@ -342,7 +334,7 @@ class TestOpsTraceEndpoints:
|
||||
|
||||
with app.test_request_context("/?tracing_provider=langfuse"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(app_model=MagicMock(id="app-1"))
|
||||
method(api, app_model=MagicMock(id="app-1"))
|
||||
|
||||
|
||||
class TestSiteEndpoints:
|
||||
@ -360,7 +352,7 @@ class TestSiteEndpoints:
|
||||
|
||||
def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = site_module.AppSite()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
site.app_id = "app-1"
|
||||
@ -386,14 +378,14 @@ class TestSiteEndpoints:
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/", json={"title": "My Site"}):
|
||||
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(api, SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["title"] == "My Site"
|
||||
|
||||
def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = site_module.AppSiteAccessTokenReset()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
site = MagicMock()
|
||||
site.app_id = "app-1"
|
||||
@ -420,7 +412,7 @@ class TestSiteEndpoints:
|
||||
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
|
||||
|
||||
with app.test_request_context("/"):
|
||||
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(api, SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["access_token"] == "code"
|
||||
@ -451,7 +443,7 @@ class TestWorkflowAppLogEndpoints:
|
||||
|
||||
def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = workflow_app_log_module.WorkflowAppLogApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
@ -481,7 +473,7 @@ class TestWorkflowAppLogEndpoints:
|
||||
)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
@ -497,7 +489,7 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
|
||||
def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
@ -532,7 +524,7 @@ class TestWorkflowDraftVariableEndpoints:
|
||||
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
|
||||
|
||||
with app.test_request_context("/?page=1&limit=20"):
|
||||
result = method(_make_account(), app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(api, _make_account(), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert result == {"items": [], "total": 0}
|
||||
|
||||
@ -565,10 +557,12 @@ class TestWorkflowStatisticEndpoints:
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
response = method(
|
||||
api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")
|
||||
)
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
|
||||
|
||||
@ -588,10 +582,12 @@ class TestWorkflowStatisticEndpoints:
|
||||
)
|
||||
|
||||
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
|
||||
response = method(
|
||||
api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")
|
||||
)
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
|
||||
|
||||
@ -610,7 +606,7 @@ class TestWorkflowTriggerEndpoints:
|
||||
|
||||
def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
api = workflow_trigger_module.WebhookTriggerApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
@ -635,7 +631,7 @@ class TestWorkflowTriggerEndpoints:
|
||||
monkeypatch.setattr(workflow_trigger_module, "sessionmaker", DummySessionMaker)
|
||||
|
||||
with app.test_request_context("/?node_id=node-1"):
|
||||
result = method(app_model=SimpleNamespace(id="app-1"))
|
||||
result = method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys())
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -12,15 +13,6 @@ from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
@ -42,7 +34,7 @@ class TestAppImportApi:
|
||||
|
||||
def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
@ -52,14 +44,14 @@ class TestAppImportApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
@ -69,14 +61,14 @@ class TestAppImportApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
monkeypatch.setattr(
|
||||
@ -88,7 +80,7 @@ class TestAppImportApi:
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
@ -96,7 +88,7 @@ class TestAppImportApi:
|
||||
|
||||
def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
@ -111,7 +103,7 @@ class TestAppImportApi:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
fake_session.commit.assert_called_once_with()
|
||||
fake_session.rollback.assert_not_called()
|
||||
@ -120,7 +112,7 @@ class TestAppImportApi:
|
||||
|
||||
def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
@ -135,7 +127,7 @@ class TestAppImportApi:
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
fake_session.rollback.assert_called_once_with()
|
||||
fake_session.commit.assert_not_called()
|
||||
@ -150,7 +142,7 @@ class TestAppImportConfirmApi:
|
||||
|
||||
def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportConfirmApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
@ -159,7 +151,7 @@ class TestAppImportConfirmApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
|
||||
response, status = method(api, SimpleNamespace(id="u1"), import_id="import-1")
|
||||
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
@ -172,7 +164,7 @@ class TestAppImportCheckDependenciesApi:
|
||||
|
||||
def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportCheckDependenciesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
@ -181,7 +173,7 @@ class TestAppImportCheckDependenciesApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
|
||||
response, status = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response, status = method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert status == 200
|
||||
assert response["leaked_dependencies"] == []
|
||||
|
||||
@ -239,13 +239,7 @@ class TestTagUnbindingPayload:
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _unwrap(method):
|
||||
"""Walk ``__wrapped__`` chain to get the original function."""
|
||||
fn = method
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
return fn
|
||||
from inspect import unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -499,7 +493,7 @@ class TestDatasetListApiPost:
|
||||
json={"name": "New Dataset"},
|
||||
):
|
||||
api = DatasetListApi()
|
||||
response, status = _unwrap(api.post)(api, tenant_id=mock_tenant.id)
|
||||
response, status = unwrap(api.post)(api, tenant_id=mock_tenant.id)
|
||||
|
||||
assert status == 200
|
||||
assert_dataset_detail_shape(response)
|
||||
@ -527,7 +521,7 @@ class TestDatasetListApiPost:
|
||||
):
|
||||
api = DatasetListApi()
|
||||
with pytest.raises(DatasetNameDuplicateError):
|
||||
_unwrap(api.post)(api, tenant_id=mock_tenant.id)
|
||||
unwrap(api.post)(api, tenant_id=mock_tenant.id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -720,7 +714,7 @@ class TestDatasetApiPatch:
|
||||
json=payload,
|
||||
):
|
||||
api = DatasetApi()
|
||||
response, status = _unwrap(api.patch)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
response, status = unwrap(api.patch)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
|
||||
assert status == 200
|
||||
assert_dataset_detail_shape(response, with_partial_members=True)
|
||||
@ -760,7 +754,7 @@ class TestDatasetApiDelete:
|
||||
method="DELETE",
|
||||
):
|
||||
api = DatasetApi()
|
||||
result = _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
result = unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
|
||||
assert result == ("", 204)
|
||||
|
||||
@ -783,7 +777,7 @@ class TestDatasetApiDelete:
|
||||
):
|
||||
api = DatasetApi()
|
||||
with pytest.raises(NotFound):
|
||||
_unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
|
||||
@patch("controllers.service_api.dataset.dataset.current_user")
|
||||
@patch("controllers.service_api.dataset.dataset.DatasetService")
|
||||
@ -804,7 +798,7 @@ class TestDatasetApiDelete:
|
||||
):
|
||||
api = DatasetApi()
|
||||
with pytest.raises(DatasetInUseError):
|
||||
_unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -19,11 +19,7 @@ def app(flask_app_with_containers) -> Flask:
|
||||
return flask_app_with_containers
|
||||
|
||||
|
||||
def _unwrap(method):
|
||||
fn = method
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
return fn
|
||||
from inspect import unwrap
|
||||
|
||||
|
||||
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
|
||||
@ -76,7 +72,7 @@ class TestAppSiteApi:
|
||||
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
|
||||
api = AppSiteApi()
|
||||
response = _unwrap(api.get)(api, app_model=app_model)
|
||||
response = unwrap(api.get)(api, app_model=app_model)
|
||||
|
||||
assert response["title"] == "Service API Site"
|
||||
assert response["icon"] == "robot"
|
||||
@ -89,7 +85,7 @@ class TestAppSiteApi:
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
_unwrap(api.get)(api, app_model=app_model)
|
||||
unwrap(api.get)(api, app_model=app_model)
|
||||
|
||||
def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
|
||||
tenant = _create_tenant(db_session_with_containers)
|
||||
@ -107,4 +103,4 @@ class TestAppSiteApi:
|
||||
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
|
||||
api = AppSiteApi()
|
||||
with pytest.raises(Forbidden):
|
||||
_unwrap(api.get)(api, app_model=app_model)
|
||||
unwrap(api.get)(api, app_model=app_model)
|
||||
|
||||
@ -62,7 +62,7 @@ def _create_app_with_site(session: Session) -> tuple[App, Account]:
|
||||
tenant_id=tenant.id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
icon_type="emoji",
|
||||
icon="app",
|
||||
icon_background="#ffffff",
|
||||
|
||||
@ -205,7 +205,7 @@ class TestHumanInputResumeNodeExecutionIntegration:
|
||||
tenant_id=tenant.id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
icon_type=IconType.EMOJI.value,
|
||||
icon="rocket",
|
||||
icon_background="#4ECDC4",
|
||||
|
||||
@ -75,7 +75,7 @@ def _app_stub(**overrides: Any) -> App:
|
||||
defaults = {
|
||||
"id": str(uuid4()),
|
||||
"tenant_id": _DEFAULT_TENANT_ID,
|
||||
"mode": AppMode.WORKFLOW.value,
|
||||
"mode": AppMode.WORKFLOW,
|
||||
"name": "n",
|
||||
"description": "d",
|
||||
"icon_type": IconType.EMOJI,
|
||||
@ -528,7 +528,7 @@ class TestAppDslService:
|
||||
|
||||
created_app = SimpleNamespace(
|
||||
id=str(uuid4()),
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
tenant_id=_DEFAULT_TENANT_ID,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
@ -707,7 +707,7 @@ class TestAppDslService:
|
||||
)
|
||||
|
||||
app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
name="old",
|
||||
description="old-desc",
|
||||
icon_type=IconType.EMOJI,
|
||||
@ -721,7 +721,7 @@ class TestAppDslService:
|
||||
app=app,
|
||||
data={
|
||||
"app": {
|
||||
"mode": AppMode.WORKFLOW.value,
|
||||
"mode": AppMode.WORKFLOW,
|
||||
"name": "yaml-name",
|
||||
"icon_type": IconType.IMAGE,
|
||||
"icon": "X",
|
||||
@ -747,7 +747,7 @@ class TestAppDslService:
|
||||
with pytest.raises(ValueError, match="Current tenant is not set"):
|
||||
service._create_or_update_app(
|
||||
app=None,
|
||||
data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}},
|
||||
data={"app": {"mode": AppMode.WORKFLOW, "name": "n"}},
|
||||
account=account,
|
||||
)
|
||||
|
||||
@ -772,7 +772,7 @@ class TestAppDslService:
|
||||
)
|
||||
]
|
||||
data = {
|
||||
"app": {"mode": AppMode.WORKFLOW.value, "name": "n"},
|
||||
"app": {"mode": AppMode.WORKFLOW, "name": "n"},
|
||||
"workflow": {
|
||||
"graph": {"nodes": []},
|
||||
"features": {},
|
||||
@ -792,8 +792,8 @@ class TestAppDslService:
|
||||
service = AppDslService(db_session_with_containers)
|
||||
with pytest.raises(ValueError, match="Missing workflow data"):
|
||||
service._create_or_update_app(
|
||||
app=_app_stub(mode=AppMode.WORKFLOW.value),
|
||||
data={"app": {"mode": AppMode.WORKFLOW.value}},
|
||||
app=_app_stub(mode=AppMode.WORKFLOW),
|
||||
data={"app": {"mode": AppMode.WORKFLOW}},
|
||||
account=_account_mock(),
|
||||
)
|
||||
|
||||
@ -852,7 +852,7 @@ class TestAppDslService:
|
||||
)
|
||||
|
||||
workflow_app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
icon_type="emoji",
|
||||
)
|
||||
AppDslService.export_dsl(workflow_app)
|
||||
@ -874,7 +874,7 @@ class TestAppDslService:
|
||||
)
|
||||
|
||||
emoji_app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
name="Emoji App",
|
||||
icon="🎨",
|
||||
icon_type=IconType.EMOJI,
|
||||
@ -889,7 +889,7 @@ class TestAppDslService:
|
||||
assert data["app"]["icon_background"] == "#FF5733"
|
||||
|
||||
image_app = _app_stub(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
name="Image App",
|
||||
icon="https://example.com/icon.png",
|
||||
icon_type=IconType.IMAGE,
|
||||
|
||||
@ -50,7 +50,7 @@ def _create_app_with_draft_workflow(
|
||||
tenant_id=tenant.id,
|
||||
name="Test App",
|
||||
description="",
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
mode=AppMode.WORKFLOW,
|
||||
icon_type="emoji",
|
||||
icon="app",
|
||||
icon_background="#ffffff",
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from typing import Protocol, cast
|
||||
|
||||
@ -26,12 +27,6 @@ from controllers.console.agent.roster import (
|
||||
from services.entities.agent_entities import ComposerSaveStrategy, ComposerVariant
|
||||
|
||||
|
||||
def _unwrap(method):
|
||||
while hasattr(method, "__wrapped__"):
|
||||
method = method.__wrapped__
|
||||
return method
|
||||
|
||||
|
||||
def _agent_response(agent_id: str = "agent-1") -> dict:
|
||||
return {
|
||||
"id": agent_id,
|
||||
@ -135,7 +130,7 @@ def test_roster_list_get_parses_query_and_calls_service(app: Flask, monkeypatch:
|
||||
monkeypatch.setattr(roster_controller.AgentRosterService, "list_roster_agents", list_roster_agents)
|
||||
|
||||
with app.test_request_context("/console/api/agents?page=2&limit=5&keyword=analyst"):
|
||||
result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1")
|
||||
result = unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1")
|
||||
|
||||
assert result["page"] == 2
|
||||
assert captured == {"tenant_id": "tenant-1", "page": 2, "limit": 5, "keyword": "analyst"}
|
||||
@ -157,7 +152,7 @@ def test_roster_list_post_creates_agent_and_returns_detail(
|
||||
)
|
||||
|
||||
with app.test_request_context(json={"name": "Analyst", "agent_soul": {"prompt": {"system_prompt": "x"}}}):
|
||||
result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id)
|
||||
result, status = unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id)
|
||||
|
||||
assert status == 201
|
||||
assert result["id"] == "agent-1"
|
||||
@ -174,7 +169,7 @@ def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.Monkey
|
||||
monkeypatch.setattr(roster_controller.AgentRosterService, "list_invite_options", list_invite_options)
|
||||
|
||||
with app.test_request_context("/console/api/agents/invite-options?page=1&limit=10&app_id=app-1"):
|
||||
result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1")
|
||||
result = unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1")
|
||||
|
||||
assert result == {"data": [], "page": 1, "limit": 10, "total": 0, "has_more": False}
|
||||
assert captured == {"tenant_id": "tenant-1", "page": 1, "limit": 10, "keyword": None, "app_id": "app-1"}
|
||||
@ -232,19 +227,19 @@ def test_roster_detail_patch_delete_and_versions_call_services(
|
||||
},
|
||||
)
|
||||
|
||||
assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id
|
||||
assert unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id
|
||||
with app.test_request_context(json={"description": "updated"}):
|
||||
assert (
|
||||
_unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"]
|
||||
unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"]
|
||||
== "updated"
|
||||
)
|
||||
assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204)
|
||||
assert unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204)
|
||||
assert archived["account_id"] == "account-1"
|
||||
assert (
|
||||
_unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
|
||||
unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
|
||||
== "version-1"
|
||||
)
|
||||
version_detail = _unwrap(AgentRosterVersionDetailApi.get)(
|
||||
version_detail = unwrap(AgentRosterVersionDetailApi.get)(
|
||||
AgentRosterVersionDetailApi(), "tenant-1", agent_id, version_id
|
||||
)
|
||||
assert version_detail["id"] == version_id
|
||||
@ -286,27 +281,27 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(
|
||||
},
|
||||
)
|
||||
|
||||
workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1")
|
||||
workflow_state = unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1")
|
||||
assert workflow_state["node_id"] == "node-1"
|
||||
with app.test_request_context(json=payload):
|
||||
saved_state = _unwrap(WorkflowAgentComposerApi.put)(
|
||||
saved_state = unwrap(WorkflowAgentComposerApi.put)(
|
||||
WorkflowAgentComposerApi(), "tenant-1", account_id, app_model, "node-1"
|
||||
)
|
||||
assert saved_state["save_options"] == ["node_job_only"]
|
||||
assert _unwrap(WorkflowAgentComposerValidateApi.post)(
|
||||
assert unwrap(WorkflowAgentComposerValidateApi.post)(
|
||||
WorkflowAgentComposerValidateApi(), app_model, "node-1"
|
||||
) == {"result": "success", "errors": []}
|
||||
assert (
|
||||
_unwrap(WorkflowAgentComposerCandidatesApi.get)(WorkflowAgentComposerCandidatesApi(), app_model, "node-1")[
|
||||
unwrap(WorkflowAgentComposerCandidatesApi.get)(WorkflowAgentComposerCandidatesApi(), app_model, "node-1")[
|
||||
"variant"
|
||||
]
|
||||
== "workflow"
|
||||
)
|
||||
with app.test_request_context(json=payload):
|
||||
assert _unwrap(WorkflowAgentComposerImpactApi.post)(
|
||||
assert unwrap(WorkflowAgentComposerImpactApi.post)(
|
||||
WorkflowAgentComposerImpactApi(), "tenant-1", app_model, "node-1"
|
||||
) == {"current_snapshot_id": "version-1", "workflow_node_count": 1, "bindings": []}
|
||||
assert _unwrap(WorkflowAgentComposerSaveToRosterApi.post)(
|
||||
assert unwrap(WorkflowAgentComposerSaveToRosterApi.post)(
|
||||
WorkflowAgentComposerSaveToRosterApi(), "tenant-1", account_id, app_model, "node-1"
|
||||
)["save_options"] == ["node_job_only"]
|
||||
|
||||
@ -315,7 +310,7 @@ def test_workflow_impact_returns_empty_without_version(app: Flask) -> None:
|
||||
payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value}
|
||||
|
||||
with app.test_request_context(json=payload):
|
||||
result = _unwrap(WorkflowAgentComposerImpactApi.post)(
|
||||
result = unwrap(WorkflowAgentComposerImpactApi.post)(
|
||||
WorkflowAgentComposerImpactApi(), "tenant-1", SimpleNamespace(id="app-1"), "node-1"
|
||||
)
|
||||
|
||||
@ -348,15 +343,15 @@ def test_agent_app_composer_get_put_validate_and_candidates(
|
||||
lambda **kwargs: _candidates_response("agent_app"),
|
||||
)
|
||||
|
||||
assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app"
|
||||
assert unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app"
|
||||
with app.test_request_context(json=payload):
|
||||
assert (
|
||||
_unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"]
|
||||
unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"]
|
||||
== "agent_app"
|
||||
)
|
||||
assert _unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == {
|
||||
assert unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == {
|
||||
"result": "success",
|
||||
"errors": [],
|
||||
}
|
||||
agent_app_candidates = _unwrap(AgentAppComposerCandidatesApi.get)(AgentAppComposerCandidatesApi(), app_model)
|
||||
agent_app_candidates = unwrap(AgentAppComposerCandidatesApi.get)(AgentAppComposerCandidatesApi(), app_model)
|
||||
assert agent_app_candidates["variant"] == "agent_app"
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -12,15 +13,6 @@ from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
@ -52,7 +44,7 @@ class TestAppImportApi:
|
||||
def test_import_post_returns_failed_status_and_rolls_back(
|
||||
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
session = _mock_session(monkeypatch)
|
||||
@ -63,7 +55,7 @@ class TestAppImportApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
@ -73,7 +65,7 @@ class TestAppImportApi:
|
||||
def test_import_post_returns_pending_status_and_commits(
|
||||
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
session = _mock_session(monkeypatch)
|
||||
@ -84,7 +76,7 @@ class TestAppImportApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
@ -94,7 +86,7 @@ class TestAppImportApi:
|
||||
def test_import_post_updates_webapp_auth_when_enabled(
|
||||
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
session = _mock_session(monkeypatch)
|
||||
@ -107,7 +99,7 @@ class TestAppImportApi:
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method(SimpleNamespace(id="u1"))
|
||||
response, status = method(api, SimpleNamespace(id="u1"))
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
@ -124,7 +116,7 @@ class TestAppImportConfirmApi:
|
||||
def test_import_confirm_returns_failed_status_and_rolls_back(
|
||||
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
@ -134,7 +126,7 @@ class TestAppImportConfirmApi:
|
||||
)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
|
||||
response, status = method(api, SimpleNamespace(id="u1"), import_id="import-1")
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
@ -32,27 +34,18 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
|
||||
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_console_audio_api_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
response = handler(app_model=app_model)
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"text": "ok"}
|
||||
|
||||
@ -71,33 +64,33 @@ def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None
|
||||
(InvokeError("invoke"), CompletionRequestError),
|
||||
],
|
||||
)
|
||||
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
def test_console_audio_api_error_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(expected):
|
||||
handler(app_model=app_model)
|
||||
handler(api, app_model=app_model)
|
||||
|
||||
|
||||
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_console_audio_api_unhandled_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
|
||||
api = ChatMessageAudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
|
||||
with pytest.raises(InternalServerError):
|
||||
handler(app_model=app_model)
|
||||
handler(api, app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_console_text_api_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
@ -105,16 +98,16 @@ def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method="POST",
|
||||
json={"text": "hello", "voice": "v"},
|
||||
):
|
||||
response = handler(app_model=app_model)
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_console_text_api_error_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
|
||||
|
||||
api = ChatMessageTextApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context(
|
||||
@ -123,23 +116,23 @@ def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
json={"text": "hello"},
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
handler(app_model=app_model)
|
||||
handler(api, app_model=app_model)
|
||||
|
||||
|
||||
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_console_text_modes_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
response = handler(app_model=app_model)
|
||||
response = handler(api, app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_console_text_modes_language_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
"transcript_tts_voices",
|
||||
@ -147,17 +140,17 @@ def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch)
|
||||
)
|
||||
|
||||
api = TextModesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(tenant_id="t1")
|
||||
|
||||
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
|
||||
with pytest.raises(AppUnavailableError):
|
||||
handler(app_model=app_model)
|
||||
handler(api, app_model=app_model)
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_audio_to_text_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
@ -171,14 +164,14 @@ def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_audio_to_text_maps_audio_too_large(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
@ -196,12 +189,12 @@ def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
method(api, app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
@ -212,14 +205,14 @@ def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_voices_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
@ -230,14 +223,14 @@ def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> N
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_audio_to_text_with_invalid_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
@ -251,13 +244,13 @@ def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_with_language_param(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
@ -268,13 +261,13 @@ def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch)
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_voices_with_language_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
AudioService,
|
||||
@ -288,5 +281,5 @@ def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.Monk
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
|
||||
@ -1,27 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.app import audio as audio_module
|
||||
from controllers.console.app.error import AudioTooLargeError
|
||||
from services.errors.audio import AudioTooLargeServiceError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_audio_to_text_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
response_payload = {"text": "hello"}
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
|
||||
@ -35,14 +28,14 @@ def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
data=data,
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
|
||||
assert response == response_payload
|
||||
|
||||
|
||||
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_audio_to_text_maps_audio_too_large(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
@ -60,12 +53,12 @@ def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
with pytest.raises(AudioTooLargeError):
|
||||
method(app_model=app_model)
|
||||
method(api, app_model=app_model)
|
||||
|
||||
|
||||
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
@ -76,14 +69,14 @@ def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method="POST",
|
||||
json={"text": "hello"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
|
||||
assert response == {"audio": "ok"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_voices_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
|
||||
|
||||
@ -94,14 +87,14 @@ def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> N
|
||||
method="GET",
|
||||
query_string={"language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
|
||||
assert response == ["voice-1"]
|
||||
|
||||
|
||||
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_audio_to_text_with_invalid_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageAudioApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
|
||||
|
||||
@ -115,13 +108,13 @@ def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -
|
||||
content_type="multipart/form-data",
|
||||
):
|
||||
# Should not raise, AudioService is mocked
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
assert response == {"text": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_with_language_param(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.ChatMessageTextApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
|
||||
|
||||
@ -132,13 +125,13 @@ def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch)
|
||||
method="POST",
|
||||
json={"text": "hello", "language": "en-US"},
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
assert response == {"audio": "test"}
|
||||
|
||||
|
||||
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_text_to_audio_voices_with_language_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = audio_module.TextModesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
monkeypatch.setattr(
|
||||
audio_module.AudioService,
|
||||
@ -152,5 +145,5 @@ def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.Monk
|
||||
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
|
||||
method="GET",
|
||||
):
|
||||
response = method(app_model=app_model)
|
||||
response = method(api, app_model=app_model)
|
||||
assert isinstance(response, list)
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.console.app import conversation as conversation_module
|
||||
@ -11,22 +13,13 @@ from models.model import AppMode
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _make_account():
|
||||
return SimpleNamespace(timezone="UTC", id="u1")
|
||||
|
||||
|
||||
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_completion_conversation_list_returns_paginated_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
@ -40,14 +33,14 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
|
||||
response = method(account, app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, account, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
|
||||
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_completion_conversation_list_invalid_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(
|
||||
@ -62,12 +55,12 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes
|
||||
query_string={"start": "bad"},
|
||||
):
|
||||
with pytest.raises(BadRequest):
|
||||
method(account, app_model=SimpleNamespace(id="app-1"))
|
||||
method(api, account, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_chat_conversation_list_advanced_chat_calls_paginate(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.ChatConversationApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
account = _make_account()
|
||||
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
|
||||
@ -81,7 +74,7 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
|
||||
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
|
||||
response = method(account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
response = method(api, account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
|
||||
|
||||
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
@ -114,7 +107,7 @@ def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPat
|
||||
|
||||
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_module.CompletionConversationDetailApi()
|
||||
method = _unwrap(api.delete)
|
||||
method = unwrap(api.delete)
|
||||
|
||||
monkeypatch.setattr(
|
||||
conversation_module.ConversationService,
|
||||
@ -123,4 +116,4 @@ def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.Monke
|
||||
)
|
||||
|
||||
with pytest.raises(NotFound):
|
||||
method(_make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
method(api, _make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@ -12,18 +13,9 @@ from controllers.console.app import conversation_variables as conversation_varia
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_get_conversation_variables_returns_paginated_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = conversation_variables_module.ConversationVariablesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
created_at = datetime(2026, 1, 1, tzinfo=UTC)
|
||||
updated_at = datetime(2026, 1, 2, tzinfo=UTC)
|
||||
@ -53,7 +45,7 @@ def test_get_conversation_variables_returns_paginated_response(app: Flask, monke
|
||||
method="GET",
|
||||
query_string={"conversation_id": "conv-1"},
|
||||
):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response["page"] == 1
|
||||
assert response["limit"] == 100
|
||||
@ -68,7 +60,7 @@ def test_get_conversation_variables_normalizes_value_type_and_value(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = conversation_variables_module.ConversationVariablesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
row = SimpleNamespace(
|
||||
created_at=None,
|
||||
@ -96,7 +88,7 @@ def test_get_conversation_variables_normalizes_value_type_and_value(
|
||||
method="GET",
|
||||
query_string={"conversation_id": "conv-1"},
|
||||
):
|
||||
response = method(app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response["data"][0]["value_type"] == "number"
|
||||
assert response["data"][0]["value"] == "42"
|
||||
@ -104,8 +96,8 @@ def test_get_conversation_variables_normalizes_value_type_and_value(
|
||||
|
||||
def test_get_conversation_variables_requires_conversation_id(app) -> None:
|
||||
api = conversation_variables_module.ConversationVariablesApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
|
||||
with pytest.raises(ValidationError):
|
||||
method(app_model=SimpleNamespace(id="app-1"))
|
||||
method(api, app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
@ -1,24 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.app import generator as generator_module
|
||||
from controllers.console.app.error import ProviderNotInitializeError
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def _model_config_payload():
|
||||
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
|
||||
@ -38,9 +31,9 @@ def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
|
||||
return service
|
||||
|
||||
|
||||
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_rule_generate_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
|
||||
|
||||
@ -49,14 +42,14 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method="POST",
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
response = method("t1")
|
||||
response = method(api, "t1")
|
||||
|
||||
assert response == {"rules": []}
|
||||
|
||||
|
||||
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_rule_code_generate_maps_token_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.RuleCodeGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise ProviderTokenNotInitError("missing token")
|
||||
@ -69,12 +62,12 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc
|
||||
json={"instruction": "do it", "model_config": _model_config_payload()},
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
|
||||
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
@ -89,16 +82,16 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method(session, "t1")
|
||||
response, status = method(api, session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "app app-1 not found"
|
||||
session.get.assert_called_once_with(generator_module.App, "app-1")
|
||||
|
||||
|
||||
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
@ -114,15 +107,15 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method(session, "t1")
|
||||
response, status = method(api, session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "workflow app-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_instruction_generate_node_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
@ -140,15 +133,15 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method(session, "t1")
|
||||
response, status = method(api, session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "node node-1 not found"
|
||||
|
||||
|
||||
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_instruction_generate_code_node(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(id="app-1")
|
||||
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
|
||||
@ -173,16 +166,16 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method(session, "t1")
|
||||
response = method(api, session, "t1")
|
||||
|
||||
assert response == {"code": "x"}
|
||||
assert workflow_service.app_model is app_model
|
||||
assert workflow_service.session is session
|
||||
|
||||
|
||||
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_instruction_generate_legacy_modify(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
session = SimpleNamespace()
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -202,14 +195,14 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response = method(session, "t1")
|
||||
response = method(api, session, "t1")
|
||||
|
||||
assert response == {"instruction": "ok"}
|
||||
|
||||
|
||||
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_instruction_generate_incompatible_params(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.InstructionGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
session = SimpleNamespace()
|
||||
|
||||
with app.test_request_context(
|
||||
@ -223,29 +216,29 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
|
||||
"model_config": _model_config_payload(),
|
||||
},
|
||||
):
|
||||
response, status = method(session, "t1")
|
||||
response, status = method(api, session, "t1")
|
||||
|
||||
assert status == 400
|
||||
assert response["error"] == "incompatible parameters"
|
||||
|
||||
|
||||
def test_instruction_template_prompt(app) -> None:
|
||||
def test_instruction_template_prompt(app: Flask) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
method="POST",
|
||||
json={"type": "prompt"},
|
||||
):
|
||||
response = method()
|
||||
response = method(api)
|
||||
|
||||
assert "data" in response
|
||||
|
||||
|
||||
def test_instruction_template_invalid_type(app) -> None:
|
||||
def test_instruction_template_invalid_type(app: Flask) -> None:
|
||||
api = generator_module.InstructionGenerationTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/instruction-generate/template",
|
||||
@ -253,7 +246,7 @@ def test_instruction_template_invalid_type(app) -> None:
|
||||
json={"type": "unknown"},
|
||||
):
|
||||
with pytest.raises(ValueError):
|
||||
method()
|
||||
method(api)
|
||||
|
||||
|
||||
# ─ /workflow-generate ─────────────────────────────────────────────────────────
|
||||
@ -281,9 +274,9 @@ def _stub_workflow_service(monkeypatch: pytest.MonkeyPatch, returns=None, raises
|
||||
monkeypatch.setattr(generator_module.WorkflowGeneratorService, "generate_workflow_graph", _call)
|
||||
|
||||
|
||||
def test_workflow_generate_returns_service_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_returns_service_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
expected = {
|
||||
"graph": {"nodes": [{"id": "node-1"}], "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 0.7}},
|
||||
@ -297,16 +290,16 @@ def test_workflow_generate_returns_service_result(app, monkeypatch: pytest.Monke
|
||||
method="POST",
|
||||
json=_workflow_generate_payload(),
|
||||
):
|
||||
response = method("t1")
|
||||
response = method(api, "t1")
|
||||
|
||||
assert response == expected
|
||||
|
||||
|
||||
def test_workflow_generate_maps_provider_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_maps_provider_token_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""ProviderTokenNotInitError → ProviderNotInitializeError so the frontend
|
||||
can render the same "provider missing" UX as /rule-generate."""
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_stub_workflow_service(monkeypatch, raises=ProviderTokenNotInitError("missing token"))
|
||||
|
||||
@ -316,15 +309,15 @@ def test_workflow_generate_maps_provider_token_error(app, monkeypatch: pytest.Mo
|
||||
json=_workflow_generate_payload(),
|
||||
):
|
||||
with pytest.raises(ProviderNotInitializeError):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
|
||||
def test_workflow_generate_maps_quota_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_maps_quota_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from controllers.console.app.error import ProviderQuotaExceededError
|
||||
from core.errors.error import QuotaExceededError
|
||||
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_stub_workflow_service(monkeypatch, raises=QuotaExceededError())
|
||||
|
||||
@ -334,15 +327,15 @@ def test_workflow_generate_maps_quota_error(app, monkeypatch: pytest.MonkeyPatch
|
||||
json=_workflow_generate_payload(),
|
||||
):
|
||||
with pytest.raises(ProviderQuotaExceededError):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
|
||||
def test_workflow_generate_maps_model_not_support_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_maps_model_not_support_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from controllers.console.app.error import ProviderModelCurrentlyNotSupportError
|
||||
from core.errors.error import ModelCurrentlyNotSupportError
|
||||
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_stub_workflow_service(monkeypatch, raises=ModelCurrentlyNotSupportError("not supported"))
|
||||
|
||||
@ -352,15 +345,15 @@ def test_workflow_generate_maps_model_not_support_error(app, monkeypatch: pytest
|
||||
json=_workflow_generate_payload(),
|
||||
):
|
||||
with pytest.raises(ProviderModelCurrentlyNotSupportError):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
|
||||
def test_workflow_generate_maps_invoke_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_maps_invoke_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
from controllers.console.app.error import CompletionRequestError
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
_stub_workflow_service(monkeypatch, raises=InvokeError("LLM unreachable"))
|
||||
|
||||
@ -370,13 +363,13 @@ def test_workflow_generate_maps_invoke_error(app, monkeypatch: pytest.MonkeyPatc
|
||||
json=_workflow_generate_payload(),
|
||||
):
|
||||
with pytest.raises(CompletionRequestError):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
|
||||
def test_workflow_generate_accepts_advanced_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_accepts_advanced_chat_mode(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""The payload Literal must accept advanced-chat as well as workflow."""
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
@ -397,17 +390,17 @@ def test_workflow_generate_accepts_advanced_chat_mode(app, monkeypatch: pytest.M
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
assert captured["mode"] == "advanced-chat"
|
||||
assert captured["instruction"] == "Summarize a URL"
|
||||
assert captured["ideal_output"] == "A 3-sentence summary."
|
||||
|
||||
|
||||
def test_workflow_generate_forwards_current_graph_for_refine(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_forwards_current_graph_for_refine(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""cmd+k `/refine`: the optional current_graph field reaches the service."""
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
@ -429,15 +422,15 @@ def test_workflow_generate_forwards_current_graph_for_refine(app, monkeypatch: p
|
||||
method="POST",
|
||||
json=payload,
|
||||
):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
assert captured["current_graph"] == graph
|
||||
|
||||
|
||||
def test_workflow_generate_current_graph_defaults_to_none(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_generate_current_graph_defaults_to_none(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Omitting current_graph (the `/create` path) forwards None to the service."""
|
||||
api = generator_module.WorkflowGenerateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
@ -456,6 +449,6 @@ def test_workflow_generate_current_graph_defaults_to_none(app, monkeypatch: pyte
|
||||
method="POST",
|
||||
json=_workflow_generate_payload(),
|
||||
):
|
||||
method("t1")
|
||||
method(api, "t1")
|
||||
|
||||
assert captured["current_graph"] is None
|
||||
|
||||
@ -7,15 +7,6 @@ import pytest
|
||||
from controllers.console.app import message as message_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test valid ChatMessagesQuery with all fields."""
|
||||
query = message_module.ChatMessagesQuery(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -11,18 +12,9 @@ from controllers.console.app import model_config as model_config_module
|
||||
from models.model import AppMode, AppModelConfig
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
@ -50,7 +42,7 @@ def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method("t1", "u1", app_model=app_model)
|
||||
response = method(api, "t1", "u1", app_model=app_model)
|
||||
|
||||
session.add.assert_called_once()
|
||||
session.flush.assert_called_once()
|
||||
@ -62,7 +54,7 @@ def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.
|
||||
|
||||
def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = model_config_module.ModelConfigResource()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
app_model = SimpleNamespace(
|
||||
id="app-1",
|
||||
@ -137,7 +129,7 @@ def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.Mon
|
||||
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
|
||||
response = method("t1", "u1", app_model=app_model)
|
||||
response = method(api, "t1", "u1", app_model=app_model)
|
||||
|
||||
stored_config = session.add.call_args[0][0]
|
||||
stored_agent_mode = json.loads(stored_config.agent_mode)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
@ -9,15 +10,6 @@ from werkzeug.exceptions import BadRequest
|
||||
from controllers.console.app import statistic as statistic_module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _ConnContext:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
@ -48,42 +40,42 @@ def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 1
|
||||
@ -94,14 +86,14 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey
|
||||
|
||||
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTerminalsStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
|
||||
|
||||
@ -111,13 +103,13 @@ def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypat
|
||||
# This just verifies the decorator is applied correctly
|
||||
# Actual endpoint testing would require complex JOIN mocking
|
||||
api = statistic_module.AverageSessionInteractionStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
assert callable(method)
|
||||
|
||||
|
||||
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
def mock_parse(*args, **kwargs):
|
||||
raise ValueError("Invalid time range")
|
||||
@ -128,12 +120,12 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
with pytest.raises(BadRequest):
|
||||
method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
|
||||
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", message_count=10),
|
||||
@ -144,7 +136,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 3
|
||||
@ -152,20 +144,20 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
|
||||
|
||||
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyMessageStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
_install_common(monkeypatch)
|
||||
_install_db(monkeypatch, [])
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": []}
|
||||
|
||||
|
||||
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyConversationStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
|
||||
_install_db(monkeypatch, rows)
|
||||
@ -177,14 +169,14 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
|
||||
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
|
||||
|
||||
|
||||
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = statistic_module.DailyTokenCostStatistic()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
rows = [
|
||||
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
|
||||
@ -194,7 +186,7 @@ def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.Monk
|
||||
_install_db(monkeypatch, rows)
|
||||
|
||||
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
|
||||
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
|
||||
|
||||
data = response.get_json()
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
@ -7,6 +8,7 @@ from typing import cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import HTTPException, NotFound
|
||||
|
||||
@ -17,12 +19,6 @@ from graphon.variables import SecretVariable, StringVariable
|
||||
from graphon.variables.variables import RAGPipelineVariable
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_workflow(**overrides):
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-1",
|
||||
@ -107,24 +103,20 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
build_mock.assert_called_once()
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_sync_draft_workflow_invalid_content_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
handler(api, "t1", app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert exc.value.code == 415
|
||||
|
||||
|
||||
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_sync_draft_workflow_invalid_json(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
@ -132,19 +124,19 @@ def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch)
|
||||
data="[]",
|
||||
content_type="application/json",
|
||||
):
|
||||
response, status = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
response, status = handler(api, "t1", app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert status == 400
|
||||
assert response["message"] == "Invalid JSON data"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_sync_draft_workflow_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="h",
|
||||
updated_at=None,
|
||||
created_at=datetime(2024, 1, 1),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
monkeypatch.setattr(
|
||||
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
|
||||
)
|
||||
@ -156,20 +148,19 @@ def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> No
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
method="POST",
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
response = handler(api, "t1", app_model=SimpleNamespace(id="app"))
|
||||
|
||||
assert response["result"] == "success"
|
||||
|
||||
|
||||
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
def test_sync_draft_workflow_hash_mismatch(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
def _raise(*_args, **_kwargs):
|
||||
raise workflow_module.WorkflowHashNotEqualError()
|
||||
@ -178,7 +169,7 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch)
|
||||
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft",
|
||||
@ -186,10 +177,10 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch)
|
||||
json={"graph": {}, "features": {}, "hash": "h"},
|
||||
):
|
||||
with pytest.raises(DraftWorkflowNotSync):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
handler(api, "t1", app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_restore_published_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
workflow = SimpleNamespace(
|
||||
unique_hash="restored-hash",
|
||||
updated_at=None,
|
||||
@ -197,7 +188,6 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
|
||||
)
|
||||
user = SimpleNamespace(id="account-1")
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -205,7 +195,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/published-workflow/restore",
|
||||
@ -213,6 +203,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
|
||||
):
|
||||
response = handler(
|
||||
api,
|
||||
"t1",
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
|
||||
workflow_id="published-workflow",
|
||||
)
|
||||
@ -221,10 +212,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
|
||||
assert response["hash"] == "restored-hash"
|
||||
|
||||
|
||||
def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
def test_restore_published_workflow_to_draft_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -236,7 +224,7 @@ def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/published-workflow/restore",
|
||||
@ -245,15 +233,15 @@ def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.
|
||||
with pytest.raises(NotFound):
|
||||
handler(
|
||||
api,
|
||||
"t1",
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
|
||||
workflow_id="published-workflow",
|
||||
)
|
||||
|
||||
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -268,7 +256,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/draft-workflow/restore",
|
||||
@ -277,6 +265,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(
|
||||
api,
|
||||
"t1",
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
|
||||
workflow_id="draft-workflow",
|
||||
)
|
||||
@ -286,11 +275,8 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m
|
||||
|
||||
|
||||
def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
user = SimpleNamespace(id="account-1")
|
||||
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -302,7 +288,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowRestoreApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/workflows/published-workflow/restore",
|
||||
@ -311,6 +297,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
handler(
|
||||
api,
|
||||
"t1",
|
||||
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
|
||||
workflow_id="published-workflow",
|
||||
)
|
||||
@ -319,9 +306,11 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
|
||||
assert exc.value.description == "invalid workflow graph"
|
||||
|
||||
|
||||
def test_get_published_workflows_serializes_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_get_published_workflows_serializes_items_before_session_closes(
|
||||
app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = workflow_module.PublishedAllWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = inspect.unwrap(api.get)
|
||||
|
||||
session_state = {"open": False}
|
||||
|
||||
@ -351,7 +340,6 @@ def test_get_published_workflows_serializes_items_before_session_closes(app, mon
|
||||
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
monkeypatch.setattr(
|
||||
workflow_module,
|
||||
"WorkflowService",
|
||||
@ -365,7 +353,7 @@ def test_get_published_workflows_serializes_items_before_session_closes(app, mon
|
||||
method="GET",
|
||||
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
|
||||
):
|
||||
response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
|
||||
response = handler(api, "t1", app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
|
||||
|
||||
assert response["items"][0]["id"] == "w1"
|
||||
assert response["page"] == 1
|
||||
@ -380,7 +368,7 @@ def test_draft_workflow_get_serializes_response_model(monkeypatch: pytest.Monkey
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = inspect.unwrap(api.get)
|
||||
|
||||
response = handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
@ -522,13 +510,13 @@ def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
)
|
||||
|
||||
api = workflow_module.DraftWorkflowApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = inspect.unwrap(api.get)
|
||||
|
||||
with pytest.raises(DraftWorkflowNotExist):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_advanced_chat_run_conversation_not_exists(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
workflow_module.AppGenerateService,
|
||||
"generate",
|
||||
@ -536,10 +524,9 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk
|
||||
workflow_module.services.errors.conversation.ConversationNotExistsError()
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
|
||||
|
||||
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/app/advanced-chat/workflows/draft/run",
|
||||
@ -547,10 +534,10 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk
|
||||
json={"inputs": {}},
|
||||
):
|
||||
with pytest.raises(NotFound):
|
||||
handler(api, app_model=SimpleNamespace(id="app"))
|
||||
handler(api, "t1", app_model=SimpleNamespace(id="app"))
|
||||
|
||||
|
||||
def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_id_1 = "11111111-1111-1111-1111-111111111111"
|
||||
app_id_2 = "22222222-2222-2222-2222-222222222222"
|
||||
signed_avatar_url = "https://files.example.com/signed/avatar-1"
|
||||
@ -602,7 +589,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p
|
||||
monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory)
|
||||
|
||||
api = workflow_module.WorkflowOnlineUsersApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/workflows/online-users",
|
||||
@ -636,7 +623,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p
|
||||
sign_avatar.assert_called_once_with("avatar-file-id")
|
||||
|
||||
|
||||
def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)]
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
monkeypatch.setattr(
|
||||
@ -653,7 +640,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk
|
||||
monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory)
|
||||
|
||||
api = workflow_module.WorkflowOnlineUsersApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/workflows/online-users",
|
||||
@ -668,7 +655,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk
|
||||
assert second_pipeline.hgetall.call_count == 1
|
||||
|
||||
|
||||
def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
|
||||
accessible_app_ids = Mock(return_value=set())
|
||||
monkeypatch.setattr(
|
||||
@ -680,7 +667,7 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch:
|
||||
excessive_ids = [f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS + 1)]
|
||||
|
||||
api = workflow_module.WorkflowOnlineUsersApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = inspect.unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/apps/workflows/online-users",
|
||||
|
||||
@ -99,7 +99,7 @@ class MutationResponseCase:
|
||||
expected_status: int | None = None
|
||||
|
||||
|
||||
def _unwrap_response(result: object) -> tuple[dict[str, object], int | None]:
|
||||
def unwrap_response(result: object) -> tuple[dict[str, object], int | None]:
|
||||
if isinstance(result, tuple):
|
||||
response, status = result
|
||||
assert isinstance(response, dict)
|
||||
@ -194,7 +194,7 @@ def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatc
|
||||
with _patch_payload(payload):
|
||||
result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123")
|
||||
|
||||
response, status = _unwrap_response(result)
|
||||
response, status = unwrap_response(result)
|
||||
assert response["id"] == "comment-1"
|
||||
assert status == 201
|
||||
create_comment_mock.assert_called_once_with(
|
||||
@ -224,7 +224,7 @@ def test_update_comment_omits_mentions_when_payload_does_not_include_them(
|
||||
with _patch_payload(payload):
|
||||
result = workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1")
|
||||
|
||||
response, status = _unwrap_response(result)
|
||||
response, status = unwrap_response(result)
|
||||
assert response == {"id": "comment-1", "updated_at": JAN_1_2024_NOON_TS}
|
||||
assert status is None
|
||||
update_comment_mock.assert_called_once_with(
|
||||
@ -480,7 +480,7 @@ def test_mutation_endpoints_serialize_response_models(
|
||||
with _patch_payload(case.payload):
|
||||
result = getattr(case.resource_cls(), case.method_name)(**case.kwargs)
|
||||
|
||||
response, status = _unwrap_response(result)
|
||||
response, status = unwrap_response(result)
|
||||
assert response == case.expected_response
|
||||
assert status == case.expected_status
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
@ -12,12 +13,6 @@ from controllers.console.app import workflow_run as workflow_run_module
|
||||
from models import Account
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _serialize_200_response(handler, payload: Any) -> Any:
|
||||
response_doc = getattr(handler, "__apidoc__", {}).get("responses", {}).get("200")
|
||||
if response_doc is None:
|
||||
@ -100,7 +95,7 @@ def test_workflow_run_list_returns_frontend_history_contract(app: Flask, monkeyp
|
||||
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
|
||||
|
||||
api = workflow_run_module.WorkflowRunListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/apps/app-1/workflow-runs?limit=10", method="GET"):
|
||||
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
@ -141,7 +136,7 @@ def test_advanced_chat_workflow_run_list_keeps_message_fields(app: Flask, monkey
|
||||
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
|
||||
|
||||
api = workflow_run_module.AdvancedChatAppWorkflowRunListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/apps/app-1/advanced-chat/workflow-runs?limit=1", method="GET"):
|
||||
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
@ -180,7 +175,7 @@ def test_workflow_run_detail_returns_frontend_detail_contract(app: Flask, monkey
|
||||
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
|
||||
|
||||
api = workflow_run_module.WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/apps/app-1/workflow-runs/run-1", method="GET"):
|
||||
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1")
|
||||
@ -217,7 +212,7 @@ def test_workflow_run_node_executions_return_frontend_trace_contract(
|
||||
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
|
||||
|
||||
api = workflow_run_module.WorkflowRunNodeExecutionListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"):
|
||||
payload = handler(
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
@ -12,12 +13,6 @@ from controllers.console.auth.data_source_bearer_auth import (
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _payload_patch(payload: dict):
|
||||
return patch.object(
|
||||
type(console_ns),
|
||||
@ -29,7 +24,7 @@ def _payload_patch(payload: dict):
|
||||
|
||||
def test_list_data_source_auth_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSource()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
binding = SimpleNamespace(
|
||||
id="binding-1",
|
||||
category="api_key",
|
||||
@ -52,7 +47,7 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None:
|
||||
|
||||
def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBinding()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
payload = {
|
||||
"category": "api_key",
|
||||
"provider": "custom",
|
||||
@ -73,7 +68,7 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
|
||||
def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBindingDelete()
|
||||
method = _unwrap(api.delete)
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
|
||||
|
||||
@ -41,13 +41,7 @@ def encode_code(code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
from inspect import unwrap
|
||||
|
||||
|
||||
class TestLoginApi:
|
||||
@ -510,7 +504,7 @@ class TestLogoutApi:
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = _unwrap(logout_api.post)(mock_account)
|
||||
response = unwrap(logout_api.post)(logout_api, mock_account)
|
||||
|
||||
# Assert
|
||||
mock_service_logout.assert_called_once_with(account=mock_account)
|
||||
@ -536,7 +530,7 @@ class TestLogoutApi:
|
||||
# Act
|
||||
with app.test_request_context("/logout", method="POST"):
|
||||
logout_api = LogoutApi()
|
||||
response = _unwrap(logout_api.post)(anonymous_user)
|
||||
response = unwrap(logout_api.post)(logout_api, anonymous_user)
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from inspect import unwrap
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi
|
||||
@ -8,12 +9,6 @@ from models.account import AccountStatus, TenantAccountRole
|
||||
from models.model import OAuthProviderApp
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_account() -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
@ -38,7 +33,7 @@ def _make_oauth_provider_app() -> OAuthProviderApp:
|
||||
|
||||
def test_oauth_authorize_uses_injected_current_user() -> None:
|
||||
api = OAuthServerUserAuthorizeApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
account = _make_account()
|
||||
oauth_provider_app = _make_oauth_provider_app()
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ class TestRefreshTokenApi:
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
|
||||
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app: Flask, mock_token_pair):
|
||||
"""
|
||||
Test successful token refresh flow.
|
||||
|
||||
@ -170,7 +170,7 @@ class TestRefreshTokenApi:
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
|
||||
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app: Flask, mock_token_pair):
|
||||
"""
|
||||
Test that token refresh updates all three tokens.
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
from inspect import unwrap
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
@ -20,12 +19,6 @@ from models.dataset import PipelineCustomizedTemplate
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||
|
||||
|
||||
def _unwrap(func: object) -> Callable[..., Any]:
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return cast(Callable[..., Any], func)
|
||||
|
||||
|
||||
def _template_item() -> dict[str, object]:
|
||||
return {
|
||||
"id": "template-1",
|
||||
@ -60,7 +53,7 @@ def _payload() -> dict[str, object]:
|
||||
class TestPipelineTemplateListApi:
|
||||
def test_get_uses_query_defaults_and_serializes_nullable_fields(self, app: Flask) -> None:
|
||||
api = PipelineTemplateListApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
service_calls: list[tuple[str, str]] = []
|
||||
|
||||
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
|
||||
@ -87,7 +80,7 @@ class TestPipelineTemplateListApi:
|
||||
|
||||
def test_get_passes_explicit_query_to_service(self, app: Flask) -> None:
|
||||
api = PipelineTemplateListApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
service_calls: list[tuple[str, str]] = []
|
||||
|
||||
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
|
||||
@ -108,7 +101,7 @@ class TestPipelineTemplateListApi:
|
||||
class TestPipelineTemplateDetailApi:
|
||||
def test_get_serializes_template_detail(self, app: Flask) -> None:
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
service_calls: list[tuple[str, str]] = []
|
||||
|
||||
class Service:
|
||||
@ -128,7 +121,7 @@ class TestPipelineTemplateDetailApi:
|
||||
|
||||
def test_get_raises_not_found_without_custom_response_body(self, app: Flask) -> None:
|
||||
api = PipelineTemplateDetailApi()
|
||||
method = _unwrap(api.get)
|
||||
method = unwrap(api.get)
|
||||
|
||||
class Service:
|
||||
def get_pipeline_template_detail(self, template_id: str, template_type: str) -> None:
|
||||
@ -145,7 +138,7 @@ class TestPipelineTemplateDetailApi:
|
||||
class TestCustomizedPipelineTemplateApi:
|
||||
def test_patch_validates_payload_and_returns_empty_204(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.patch)
|
||||
method = unwrap(api.patch)
|
||||
payload = _payload()
|
||||
service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = []
|
||||
|
||||
@ -174,7 +167,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
|
||||
def test_patch_defaults_missing_icon_info_before_service_call(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.patch)
|
||||
method = unwrap(api.patch)
|
||||
payload: dict[str, object] = {
|
||||
"name": "Updated template",
|
||||
"description": "Updated description",
|
||||
@ -204,7 +197,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
|
||||
def test_delete_returns_empty_204(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.delete)
|
||||
method = unwrap(api.delete)
|
||||
deleted_template_ids: list[str] = []
|
||||
|
||||
def delete_template(template_id: str) -> None:
|
||||
@ -221,7 +214,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
|
||||
def test_post_exports_yaml_from_orm_template(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
template = PipelineCustomizedTemplate(
|
||||
tenant_id="00000000-0000-0000-0000-000000000001",
|
||||
name="Template",
|
||||
@ -265,7 +258,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
|
||||
def test_post_raises_when_template_is_missing(self, app: Flask) -> None:
|
||||
api = CustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
|
||||
class Session:
|
||||
def scalar(self, stmt: object) -> None:
|
||||
@ -297,7 +290,7 @@ class TestCustomizedPipelineTemplateApi:
|
||||
class TestPublishCustomizedPipelineTemplateApi:
|
||||
def test_post_validates_payload_and_returns_empty_204(self, app: Flask) -> None:
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
payload = _payload()
|
||||
service_calls: list[tuple[str, dict[str, object]]] = []
|
||||
|
||||
@ -317,7 +310,7 @@ class TestPublishCustomizedPipelineTemplateApi:
|
||||
|
||||
def test_post_allows_missing_icon_info_for_publish_service_fallback(self, app: Flask) -> None:
|
||||
api = PublishCustomizedPipelineTemplateApi()
|
||||
method = _unwrap(api.post)
|
||||
method = unwrap(api.post)
|
||||
payload: dict[str, object] = {
|
||||
"name": "Published template",
|
||||
"description": "Description",
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
@ -9,12 +10,6 @@ import pytest
|
||||
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_workflow(**overrides):
|
||||
workflow = SimpleNamespace(
|
||||
id="workflow-1",
|
||||
@ -45,7 +40,7 @@ def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch:
|
||||
)
|
||||
|
||||
api = module.DraftRagPipelineApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
|
||||
|
||||
@ -63,7 +58,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = module.PublishedAllRagPipelineApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
session_state = {"open": False}
|
||||
|
||||
class _SessionContext:
|
||||
@ -133,7 +128,7 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch:
|
||||
payload: dict[str, object] = {"marked_name": "Updated release"}
|
||||
|
||||
api = module.RagPipelineByIdApi()
|
||||
handler = _unwrap(api.patch)
|
||||
handler = unwrap(api.patch)
|
||||
|
||||
with (
|
||||
app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload),
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -23,12 +24,6 @@ from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def test_jsonify_form_definition() -> None:
|
||||
expiration = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
|
||||
@ -64,7 +59,7 @@ def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
response = handler(api, "tenant-1", form_token="token")
|
||||
@ -85,7 +80,7 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
@ -106,7 +101,7 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
@ -137,7 +132,7 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
@ -172,7 +167,7 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context(
|
||||
"/console/api/form/human_input/token",
|
||||
@ -244,7 +239,7 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch)
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
@ -271,7 +266,7 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
@ -298,7 +293,7 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
with pytest.raises(NotFoundError):
|
||||
@ -342,7 +337,7 @@ def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -
|
||||
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = ConsoleWorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
|
||||
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
|
||||
response = handler(api, "t1", SimpleNamespace(id="user-1"), workflow_run_id="run-1")
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import urllib.parse
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -16,12 +17,6 @@ from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
|
||||
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _make_account(account_id: str = "u1") -> Account:
|
||||
account = Account(
|
||||
name="Test User",
|
||||
@ -89,7 +84,7 @@ def _mock_upload_dependencies(
|
||||
|
||||
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
@ -110,7 +105,7 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest
|
||||
|
||||
def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
target_url = "http://example.com/api/aiagent/httpview/txt"
|
||||
query = "fileNameKey=cankao1_ce4305bc-be20-4c5d-8732-de1741d28e27"
|
||||
|
||||
@ -131,7 +126,7 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch:
|
||||
|
||||
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.GetRemoteFileInfo()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
decoded_url = "https://example.com/test.txt"
|
||||
encoded_url = urllib.parse.quote(decoded_url, safe="")
|
||||
|
||||
@ -154,7 +149,7 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo
|
||||
|
||||
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/report.txt"
|
||||
|
||||
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
|
||||
@ -196,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/photo.jpg"
|
||||
|
||||
head_resp = _FakeResponse(status_code=200, method="HEAD", content=b"head-content")
|
||||
@ -227,7 +222,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
|
||||
|
||||
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
make_request = MagicMock(
|
||||
@ -245,7 +240,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
|
||||
|
||||
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/fail.txt"
|
||||
|
||||
request = httpx.Request("HEAD", url)
|
||||
@ -259,7 +254,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
|
||||
|
||||
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
|
||||
@ -274,7 +269,7 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
|
||||
|
||||
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/large.bin"
|
||||
|
||||
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
|
||||
@ -289,7 +284,7 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
|
||||
|
||||
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = remote_files_module.RemoteFileUpload()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
url = "https://example.com/file.exe"
|
||||
|
||||
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
|
||||
|
||||
@ -10,6 +10,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi.auth.data import AuthData
|
||||
@ -30,7 +31,7 @@ def _make_auth_data(app_model, caller, caller_kind):
|
||||
|
||||
|
||||
class TestOpenApiHumanInputFormGet:
|
||||
def test_get_success(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_get_success(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
definition = SimpleNamespace(
|
||||
@ -74,7 +75,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
assert payload["user_actions"] == [{"id": "submit", "title": "Submit"}]
|
||||
service_mock.ensure_form_active.assert_called_once_with(form)
|
||||
|
||||
def test_get_form_not_found(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_get_form_not_found(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
service_mock = Mock()
|
||||
@ -96,7 +97,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_get_form_wrong_app(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = SimpleNamespace(
|
||||
@ -121,7 +122,7 @@ class TestOpenApiHumanInputFormGet:
|
||||
auth_data=_make_auth_data(app_model, caller, "account"),
|
||||
)
|
||||
|
||||
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_get_form_wrong_surface(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = SimpleNamespace(
|
||||
@ -159,7 +160,7 @@ class TestOpenApiHumanInputFormPost:
|
||||
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
def test_post_account_caller_uses_user_id(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_post_account_caller_uses_user_id(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = self._make_form()
|
||||
@ -196,7 +197,7 @@ class TestOpenApiHumanInputFormPost:
|
||||
)
|
||||
assert result == ({}, 200)
|
||||
|
||||
def test_post_end_user_caller_uses_end_user_id(self, app, bypass_pipeline, monkeypatch):
|
||||
def test_post_end_user_caller_uses_end_user_id(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
|
||||
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
|
||||
|
||||
form = self._make_form()
|
||||
|
||||
@ -13,6 +13,7 @@ Note: API endpoint tests for annotation controllers are complex due to:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -34,13 +35,6 @@ from extensions.ext_redis import redis_client
|
||||
from models.model import App
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pydantic Model Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -193,7 +187,7 @@ class TestAnnotationReplyActionApi:
|
||||
monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock)
|
||||
|
||||
api = AnnotationReplyActionApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(
|
||||
@ -211,7 +205,7 @@ class TestAnnotationReplyActionApi:
|
||||
monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock)
|
||||
|
||||
api = AnnotationReplyActionApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(
|
||||
@ -230,7 +224,7 @@ class TestAnnotationReplyActionStatusApi:
|
||||
monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None)
|
||||
|
||||
api = AnnotationReplyActionStatusApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
@ -245,7 +239,7 @@ class TestAnnotationReplyActionStatusApi:
|
||||
monkeypatch.setattr(redis_client, "get", _get)
|
||||
|
||||
api = AnnotationReplyActionStatusApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
response, status = handler(api, app_model=app_model, job_id="j1", action="enable")
|
||||
@ -262,7 +256,7 @@ class TestAnnotationListApi:
|
||||
monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations", method="GET"):
|
||||
@ -278,7 +272,7 @@ class TestAnnotationListApi:
|
||||
monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations?page=2&limit=5&keyword=refund", method="GET"):
|
||||
@ -297,7 +291,7 @@ class TestAnnotationListApi:
|
||||
monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context(f"/apps/annotations?{query_string}", method="GET"):
|
||||
@ -315,7 +309,7 @@ class TestAnnotationListApi:
|
||||
)
|
||||
|
||||
api = AnnotationListApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}):
|
||||
@ -337,8 +331,8 @@ class TestAnnotationUpdateDeleteApi:
|
||||
monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock)
|
||||
|
||||
api = AnnotationUpdateDeleteApi()
|
||||
put_handler = _unwrap(api.put)
|
||||
delete_handler = _unwrap(api.delete)
|
||||
put_handler = unwrap(api.put)
|
||||
delete_handler = unwrap(api.delete)
|
||||
app_model = SimpleNamespace(id="app")
|
||||
|
||||
with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}):
|
||||
|
||||
@ -9,6 +9,7 @@ Tests coverage for:
|
||||
|
||||
import io
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -41,12 +42,6 @@ from services.errors.audio import (
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
def _file_data():
|
||||
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
|
||||
|
||||
@ -194,7 +189,7 @@ class TestAudioApi:
|
||||
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
@ -220,7 +215,7 @@ class TestAudioApi:
|
||||
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
|
||||
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
@ -233,7 +228,7 @@ class TestAudioApi:
|
||||
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
api = AudioApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
@ -247,7 +242,7 @@ class TestTextApi:
|
||||
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
|
||||
|
||||
api = TextApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
|
||||
@ -266,7 +261,7 @@ class TestTextApi:
|
||||
)
|
||||
|
||||
api = TextApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
end_user = SimpleNamespace(external_user_id="ext")
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ Focus on:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -44,12 +45,6 @@ from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestCompletionRequestPayload:
|
||||
"""Test suite for CompletionRequestPayload Pydantic model."""
|
||||
|
||||
@ -426,7 +421,7 @@ class TestChatRequestPayloadController:
|
||||
class TestCompletionApiController:
|
||||
def test_wrong_mode(self, app: Flask) -> None:
|
||||
api = CompletionApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -444,7 +439,7 @@ class TestCompletionApiController:
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
api = CompletionApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
|
||||
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
|
||||
with pytest.raises(NotFound):
|
||||
@ -454,7 +449,7 @@ class TestCompletionApiController:
|
||||
class TestCompletionStopApiController:
|
||||
def test_wrong_mode(self, app: Flask) -> None:
|
||||
api = CompletionStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
@ -467,7 +462,7 @@ class TestCompletionStopApiController:
|
||||
monkeypatch.setattr(AppTaskService, "stop_task", stop_mock)
|
||||
|
||||
api = CompletionStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
@ -481,7 +476,7 @@ class TestCompletionStopApiController:
|
||||
class TestChatApiController:
|
||||
def test_wrong_mode(self, app: Flask) -> None:
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -497,7 +492,7 @@ class TestChatApiController:
|
||||
)
|
||||
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -513,7 +508,7 @@ class TestChatApiController:
|
||||
)
|
||||
|
||||
api = ChatApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -525,7 +520,7 @@ class TestChatApiController:
|
||||
class TestChatStopApiController:
|
||||
def test_wrong_mode(self, app: Flask) -> None:
|
||||
api = ChatStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ Focus on:
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -51,12 +52,6 @@ from services.errors.conversation import (
|
||||
)
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestConversationListQuery:
|
||||
"""Test suite for ConversationListQuery Pydantic model."""
|
||||
|
||||
@ -380,7 +375,7 @@ class TestConversationAppModeValidation:
|
||||
app raises NotChatAppError.
|
||||
"""
|
||||
app = Mock(spec=App)
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
|
||||
app_mode = AppMode.value_of(app.mode)
|
||||
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}
|
||||
@ -498,7 +493,7 @@ class TestConversationPayloadsController:
|
||||
class TestConversationApiController:
|
||||
def test_list_not_chat(self, app: Flask) -> None:
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -531,7 +526,7 @@ class TestConversationApiController:
|
||||
monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub)
|
||||
|
||||
api = ConversationApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -546,7 +541,7 @@ class TestConversationApiController:
|
||||
class TestConversationDetailApiController:
|
||||
def test_delete_not_chat(self, app: Flask) -> None:
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
handler = unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -562,7 +557,7 @@ class TestConversationDetailApiController:
|
||||
)
|
||||
|
||||
api = ConversationDetailApi()
|
||||
handler = _unwrap(api.delete)
|
||||
handler = unwrap(api.delete)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -580,7 +575,7 @@ class TestConversationRenameApiController:
|
||||
)
|
||||
|
||||
api = ConversationRenameApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -596,7 +591,7 @@ class TestConversationRenameApiController:
|
||||
class TestConversationVariablesApiController:
|
||||
def test_not_chat(self, app: Flask) -> None:
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -612,7 +607,7 @@ class TestConversationVariablesApiController:
|
||||
)
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -645,7 +640,7 @@ class TestConversationVariablesApiController:
|
||||
)
|
||||
|
||||
api = ConversationVariablesApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -671,7 +666,7 @@ class TestConversationVariableDetailApiController:
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
handler = unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -697,7 +692,7 @@ class TestConversationVariableDetailApiController:
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
handler = unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -731,7 +726,7 @@ class TestConversationVariableDetailApiController:
|
||||
)
|
||||
|
||||
api = ConversationVariableDetailApi()
|
||||
handler = _unwrap(api.put)
|
||||
handler = unwrap(api.put)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
|
||||
@ -203,7 +203,7 @@ class TestFileUploadResponse:
|
||||
# unwrapped method directly to bypass the decorator.
|
||||
# =============================================================================
|
||||
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
from inspect import unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -274,7 +274,7 @@ class TestFileApiPost:
|
||||
data=data,
|
||||
):
|
||||
api = FileApi()
|
||||
response, status = _unwrap(api.post)(
|
||||
response, status = unwrap(api.post)(
|
||||
api,
|
||||
app_model=mock_app_model,
|
||||
end_user=mock_end_user,
|
||||
@ -295,7 +295,7 @@ class TestFileApiPost:
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(NoFileUploadedError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user):
|
||||
"""Test TooManyFilesError when multiple files uploaded."""
|
||||
@ -316,7 +316,7 @@ class TestFileApiPost:
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(TooManyFilesError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user):
|
||||
"""Test UnsupportedFileTypeError when file has no mimetype."""
|
||||
@ -334,7 +334,7 @@ class TestFileApiPost:
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
@ -366,7 +366,7 @@ class TestFileApiPost:
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(FileTooLargeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@patch("controllers.service_api.app.file.FileService")
|
||||
@patch("controllers.service_api.app.file.db")
|
||||
@ -396,4 +396,4 @@ class TestFileApiPost:
|
||||
):
|
||||
api = FileApi()
|
||||
with pytest.raises(UnsupportedFileTypeError):
|
||||
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
|
||||
|
||||
@ -7,6 +7,7 @@ import sys
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from typing import override
|
||||
from unittest.mock import ANY, MagicMock, Mock
|
||||
@ -44,7 +45,6 @@ from repositories.api_workflow_node_execution_repository import WorkflowNodeExec
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.workflow_event_snapshot_service import _build_snapshot_events
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class _DummyRateLimit:
|
||||
@ -275,8 +275,8 @@ class TestHitlServiceApi:
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
|
||||
@ -310,8 +310,8 @@ class TestHitlServiceApi:
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context(
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -15,7 +16,6 @@ from werkzeug.exceptions import NotFound
|
||||
from controllers.common.human_input import HumanInputFormSubmitPayload
|
||||
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
|
||||
from models.human_input import RecipientType
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
class TestWorkflowHumanInputFormApi:
|
||||
@ -45,7 +45,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
@ -98,7 +98,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
@ -121,7 +121,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
@ -153,7 +153,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with app.test_request_context("/form/human_input/token-1", method="GET"):
|
||||
@ -175,7 +175,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
@ -209,7 +209,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
inputs = {
|
||||
@ -285,7 +285,7 @@ class TestWorkflowHumanInputFormApi:
|
||||
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
|
||||
|
||||
api = WorkflowHumanInputFormApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ Focus on:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -43,12 +44,6 @@ from services.errors.message import (
|
||||
from services.message_service import MessageService
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
return func
|
||||
|
||||
|
||||
class TestMessageListQuery:
|
||||
"""Test suite for MessageListQuery Pydantic model."""
|
||||
|
||||
@ -383,7 +378,7 @@ class TestMessageService:
|
||||
class TestMessageListApi:
|
||||
def test_not_chat_app(self, app: Flask) -> None:
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -399,7 +394,7 @@ class TestMessageListApi:
|
||||
)
|
||||
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -418,7 +413,7 @@ class TestMessageListApi:
|
||||
)
|
||||
|
||||
api = MessageListApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -439,7 +434,7 @@ class TestMessageFeedbackApi:
|
||||
)
|
||||
|
||||
api = MessageFeedbackApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace()
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -457,7 +452,7 @@ class TestAppGetFeedbacksApi:
|
||||
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
|
||||
|
||||
api = AppGetFeedbacksApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"):
|
||||
@ -469,7 +464,7 @@ class TestAppGetFeedbacksApi:
|
||||
class TestMessageSuggestedApi:
|
||||
def test_not_chat(self, app: Flask) -> None:
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -485,7 +480,7 @@ class TestMessageSuggestedApi:
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -501,7 +496,7 @@ class TestMessageSuggestedApi:
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -517,7 +512,7 @@ class TestMessageSuggestedApi:
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -533,7 +528,7 @@ class TestMessageSuggestedApi:
|
||||
)
|
||||
|
||||
api = MessageSuggestedApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ Focus on:
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@ -369,7 +370,7 @@ class TestWorkflowRunRepository:
|
||||
class TestWorkflowRunDetailApi:
|
||||
def test_not_workflow_app(self, app: Flask) -> None:
|
||||
api = WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
|
||||
with app.test_request_context("/workflows/run/1", method="GET"):
|
||||
@ -388,8 +389,8 @@ class TestWorkflowRunDetailApi:
|
||||
)
|
||||
|
||||
api = WorkflowRunDetailApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW, tenant_id="t1", id="a1")
|
||||
|
||||
result = handler(api, app_model=app_model, workflow_run_id="run")
|
||||
assert result["id"] == "run"
|
||||
@ -400,7 +401,7 @@ class TestWorkflowRunDetailApi:
|
||||
class TestWorkflowRunApi:
|
||||
def test_not_workflow_app(self, app: Flask) -> None:
|
||||
api = WorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -416,8 +417,8 @@ class TestWorkflowRunApi:
|
||||
)
|
||||
|
||||
api = WorkflowRunApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
|
||||
@ -434,8 +435,8 @@ class TestWorkflowRunByIdApi:
|
||||
)
|
||||
|
||||
api = WorkflowRunByIdApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
|
||||
@ -450,8 +451,8 @@ class TestWorkflowRunByIdApi:
|
||||
)
|
||||
|
||||
api = WorkflowRunByIdApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
|
||||
@ -462,7 +463,7 @@ class TestWorkflowRunByIdApi:
|
||||
class TestWorkflowTaskStopApi:
|
||||
def test_wrong_mode(self, app: Flask) -> None:
|
||||
api = WorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace()
|
||||
|
||||
@ -477,8 +478,8 @@ class TestWorkflowTaskStopApi:
|
||||
monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock)
|
||||
|
||||
api = WorkflowTaskStopApi()
|
||||
handler = _unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.post)
|
||||
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="u1")
|
||||
|
||||
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
|
||||
@ -515,7 +516,7 @@ class TestWorkflowAppLogApi:
|
||||
)
|
||||
|
||||
api = WorkflowAppLogApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="a1")
|
||||
|
||||
with app.test_request_context("/workflows/logs", method="GET"):
|
||||
@ -533,15 +534,13 @@ class TestWorkflowAppLogApi:
|
||||
# directly to bypass the decorator.
|
||||
# =============================================================================
|
||||
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_workflow_app():
|
||||
app = Mock(spec=App)
|
||||
app.id = str(uuid.uuid4())
|
||||
app.tenant_id = str(uuid.uuid4())
|
||||
app.mode = AppMode.WORKFLOW.value
|
||||
app.mode = AppMode.WORKFLOW
|
||||
return app
|
||||
|
||||
|
||||
@ -574,7 +573,7 @@ class TestWorkflowRunDetailApiGet:
|
||||
method="GET",
|
||||
):
|
||||
api = WorkflowRunDetailApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
||||
result = unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
|
||||
|
||||
assert result["id"] == mock_run.id
|
||||
assert result["status"] == "succeeded"
|
||||
@ -590,7 +589,7 @@ class TestWorkflowRunDetailApiGet:
|
||||
with app.test_request_context("/workflows/run/run-1", method="GET"):
|
||||
api = WorkflowRunDetailApi()
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
_unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
|
||||
unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
|
||||
|
||||
|
||||
class TestWorkflowTaskStopApiPost:
|
||||
@ -613,7 +612,7 @@ class TestWorkflowTaskStopApiPost:
|
||||
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
api = WorkflowTaskStopApi()
|
||||
result = _unwrap(api.post)(
|
||||
result = unwrap(api.post)(
|
||||
api,
|
||||
app_model=mock_workflow_app,
|
||||
end_user=Mock(),
|
||||
@ -635,7 +634,7 @@ class TestWorkflowTaskStopApiPost:
|
||||
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
|
||||
api = WorkflowTaskStopApi()
|
||||
with pytest.raises(NotWorkflowAppError):
|
||||
_unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
|
||||
unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
|
||||
|
||||
|
||||
class TestWorkflowAppLogApiGet:
|
||||
@ -681,6 +680,6 @@ class TestWorkflowAppLogApiGet:
|
||||
):
|
||||
with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory):
|
||||
api = WorkflowAppLogApi()
|
||||
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||
result = unwrap(api.get)(api, app_model=mock_workflow_app)
|
||||
|
||||
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
|
||||
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import sys
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock
|
||||
|
||||
@ -16,7 +17,6 @@ from controllers.service_api.app.error import NotWorkflowAppError
|
||||
from controllers.service_api.app.workflow_events import WorkflowEventsApi
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
@ -34,7 +34,7 @@ def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
|
||||
class TestWorkflowEventsApi:
|
||||
def test_wrong_app_mode(self, app: Flask) -> None:
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
@ -45,8 +45,8 @@ class TestWorkflowEventsApi:
|
||||
def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=None)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
@ -63,8 +63,8 @@ class TestWorkflowEventsApi:
|
||||
)
|
||||
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
@ -90,8 +90,8 @@ class TestWorkflowEventsApi:
|
||||
)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
@ -121,8 +121,8 @@ class TestWorkflowEventsApi:
|
||||
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
|
||||
@ -154,8 +154,8 @@ class TestWorkflowEventsApi:
|
||||
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
|
||||
|
||||
api = WorkflowEventsApi()
|
||||
handler = _unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
handler = unwrap(api.get)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
end_user = SimpleNamespace(id="end-user-1")
|
||||
|
||||
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):
|
||||
|
||||
@ -204,11 +204,3 @@ def mock_child_chunk():
|
||||
child_chunk.tenant_id = str(uuid.uuid4())
|
||||
child_chunk.content = "Test child chunk content"
|
||||
return child_chunk
|
||||
|
||||
|
||||
def _unwrap(method):
|
||||
"""Walk ``__wrapped__`` chain to get the original function."""
|
||||
fn = method
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
return fn
|
||||
|
||||
@ -16,6 +16,7 @@ Decorator strategy:
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from inspect import unwrap
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@ -29,7 +30,6 @@ from controllers.service_api.dataset.metadata import (
|
||||
DatasetMetadataServiceApi,
|
||||
DocumentMetadataEditServiceApi,
|
||||
)
|
||||
from tests.unit_tests.controllers.service_api.conftest import _unwrap
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -65,7 +65,7 @@ class TestDatasetMetadataCreatePost:
|
||||
|
||||
@staticmethod
|
||||
def _call_post(api, **kwargs):
|
||||
return _unwrap(api.post)(api, **kwargs)
|
||||
return unwrap(api.post)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@ -195,7 +195,7 @@ class TestDatasetMetadataServiceApiPatch:
|
||||
|
||||
@staticmethod
|
||||
def _call_patch(api, **kwargs):
|
||||
return _unwrap(api.patch)(api, **kwargs)
|
||||
return unwrap(api.patch)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@ -267,7 +267,7 @@ class TestDatasetMetadataServiceApiDelete:
|
||||
|
||||
@staticmethod
|
||||
def _call_delete(api, **kwargs):
|
||||
return _unwrap(api.delete)(api, **kwargs)
|
||||
return unwrap(api.delete)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@ -376,7 +376,7 @@ class TestDatasetMetadataBuiltInFieldAction:
|
||||
|
||||
@staticmethod
|
||||
def _call_post(api, **kwargs):
|
||||
return _unwrap(api.post)(api, **kwargs)
|
||||
return unwrap(api.post)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
@ -479,7 +479,7 @@ class TestDocumentMetadataEditPost:
|
||||
|
||||
@staticmethod
|
||||
def _call_post(api, **kwargs):
|
||||
return _unwrap(api.post)(api, **kwargs)
|
||||
return unwrap(api.post)(api, **kwargs)
|
||||
|
||||
@patch("controllers.service_api.dataset.metadata.MetadataService")
|
||||
@patch("controllers.service_api.dataset.metadata.DatasetService")
|
||||
|
||||
@ -7,7 +7,7 @@ from models.model import AppMode
|
||||
|
||||
class TestWorkflowAppConfigManager:
|
||||
def test_get_app_config(self):
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
|
||||
workflow = SimpleNamespace(id="wf-1", features_dict={})
|
||||
|
||||
with (
|
||||
|
||||
@ -383,7 +383,7 @@ class TestWorkflowService:
|
||||
|
||||
assert result == mock_workflow
|
||||
|
||||
def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service):
|
||||
def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service: WorkflowService):
|
||||
"""Test get_draft_workflow passes an injected session to published workflow lookup."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
workflow_id = "workflow-123"
|
||||
@ -458,7 +458,7 @@ class TestWorkflowService:
|
||||
|
||||
assert result == mock_workflow
|
||||
|
||||
def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service):
|
||||
def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service: WorkflowService):
|
||||
"""Test get_published_workflow returns None when app has no workflow_id."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
|
||||
|
||||
@ -658,21 +658,21 @@ class TestWorkflowService:
|
||||
# ==================== Workflow Validation Tests ====================
|
||||
# These tests verify graph structure and feature configuration validation
|
||||
|
||||
def test_validate_graph_structure_empty_graph(self, workflow_service):
|
||||
def test_validate_graph_structure_empty_graph(self, workflow_service: WorkflowService):
|
||||
"""Test validate_graph_structure accepts empty graph."""
|
||||
graph = {"nodes": []}
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_graph_structure_valid_graph(self, workflow_service):
|
||||
def test_validate_graph_structure_valid_graph(self, workflow_service: WorkflowService):
|
||||
"""Test validate_graph_structure accepts valid graph."""
|
||||
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
|
||||
|
||||
# Should not raise any exception
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service):
|
||||
def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test validate_graph_structure raises error when start and trigger nodes coexist.
|
||||
|
||||
@ -707,7 +707,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
|
||||
workflow_service.validate_graph_structure(graph)
|
||||
|
||||
def test_validate_features_structure_workflow_mode(self, workflow_service):
|
||||
def test_validate_features_structure_workflow_mode(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test validate_features_structure for workflow mode.
|
||||
|
||||
@ -723,7 +723,7 @@ class TestWorkflowService:
|
||||
tenant_id=app.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
|
||||
def test_validate_features_structure_advanced_chat_mode(self, workflow_service):
|
||||
def test_validate_features_structure_advanced_chat_mode(self, workflow_service: WorkflowService):
|
||||
"""Test validate_features_structure for advanced chat mode."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.ADVANCED_CHAT)
|
||||
features = {"opening_statement": "Hello"}
|
||||
@ -734,7 +734,7 @@ class TestWorkflowService:
|
||||
tenant_id=app.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
|
||||
def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service):
|
||||
def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service: WorkflowService):
|
||||
"""Test validate_features_structure raises error for invalid mode."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION)
|
||||
features = {}
|
||||
@ -767,7 +767,7 @@ class TestWorkflowService:
|
||||
assert workflow.updated_at == "now"
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
|
||||
def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service):
|
||||
def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service: WorkflowService):
|
||||
"""Test update_draft_workflow_environment_variables raises when draft missing."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
account = TestWorkflowAssociatedDataFactory.create_account_mock()
|
||||
@ -802,7 +802,7 @@ class TestWorkflowService:
|
||||
assert workflow.updated_at == "now"
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
|
||||
def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service):
|
||||
def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service: WorkflowService):
|
||||
"""Test update_draft_workflow_conversation_variables raises when draft missing."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock()
|
||||
account = TestWorkflowAssociatedDataFactory.create_account_mock()
|
||||
@ -917,7 +917,7 @@ class TestWorkflowService:
|
||||
# ==================== Version Management Tests ====================
|
||||
# These tests verify listing and managing published workflow versions
|
||||
|
||||
def test_get_all_published_workflow_with_pagination(self, workflow_service):
|
||||
def test_get_all_published_workflow_with_pagination(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test get_all_published_workflow returns paginated results.
|
||||
|
||||
@ -950,7 +950,7 @@ class TestWorkflowService:
|
||||
assert len(workflows) == 5
|
||||
assert has_more is False
|
||||
|
||||
def test_get_all_published_workflow_has_more(self, workflow_service):
|
||||
def test_get_all_published_workflow_has_more(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test get_all_published_workflow indicates has_more when results exceed limit.
|
||||
|
||||
@ -983,7 +983,7 @@ class TestWorkflowService:
|
||||
assert len(workflows) == 10
|
||||
assert has_more is True
|
||||
|
||||
def test_get_all_published_workflow_no_workflow_id(self, workflow_service):
|
||||
def test_get_all_published_workflow_no_workflow_id(self, workflow_service: WorkflowService):
|
||||
"""Test get_all_published_workflow returns empty when app has no workflow_id."""
|
||||
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
|
||||
mock_session = MagicMock()
|
||||
@ -998,7 +998,7 @@ class TestWorkflowService:
|
||||
# ==================== Update Workflow Tests ====================
|
||||
# These tests verify updating workflow metadata (name, comments, etc.)
|
||||
|
||||
def test_update_workflow_success(self, workflow_service):
|
||||
def test_update_workflow_success(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test update_workflow updates workflow attributes.
|
||||
|
||||
@ -1032,7 +1032,7 @@ class TestWorkflowService:
|
||||
assert mock_workflow.marked_comment == "Updated Comment"
|
||||
assert mock_workflow.updated_by == account_id
|
||||
|
||||
def test_update_workflow_not_found(self, workflow_service):
|
||||
def test_update_workflow_not_found(self, workflow_service: WorkflowService):
|
||||
"""Test update_workflow returns None when workflow not found."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = None
|
||||
@ -1055,7 +1055,7 @@ class TestWorkflowService:
|
||||
# ==================== Delete Workflow Tests ====================
|
||||
# These tests verify workflow deletion with safety checks
|
||||
|
||||
def test_delete_workflow_success(self, workflow_service):
|
||||
def test_delete_workflow_success(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test delete_workflow successfully deletes a published workflow.
|
||||
|
||||
@ -1085,7 +1085,7 @@ class TestWorkflowService:
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once_with(mock_workflow)
|
||||
|
||||
def test_delete_workflow_draft_raises_error(self, workflow_service):
|
||||
def test_delete_workflow_draft_raises_error(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test delete_workflow raises error when trying to delete draft.
|
||||
|
||||
@ -1109,7 +1109,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow"):
|
||||
workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
|
||||
|
||||
def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service):
|
||||
def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test delete_workflow raises error when workflow is in use by app.
|
||||
|
||||
@ -1132,7 +1132,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
|
||||
workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
|
||||
|
||||
def test_delete_workflow_published_as_tool_raises_error(self, workflow_service):
|
||||
def test_delete_workflow_published_as_tool_raises_error(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test delete_workflow raises error when workflow is published as tool.
|
||||
|
||||
@ -1156,7 +1156,7 @@ class TestWorkflowService:
|
||||
with pytest.raises(WorkflowInUseError, match="published as a tool"):
|
||||
workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
|
||||
|
||||
def test_delete_workflow_not_found_raises_error(self, workflow_service):
|
||||
def test_delete_workflow_not_found_raises_error(self, workflow_service: WorkflowService):
|
||||
"""Test delete_workflow raises error when workflow not found."""
|
||||
workflow_id = "nonexistent"
|
||||
tenant_id = "tenant-456"
|
||||
@ -1175,7 +1175,7 @@ class TestWorkflowService:
|
||||
# ==================== Get Default Block Config Tests ====================
|
||||
# These tests verify retrieval of default node configurations
|
||||
|
||||
def test_get_default_block_configs(self, workflow_service):
|
||||
def test_get_default_block_configs(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test get_default_block_configs returns list of default configs.
|
||||
|
||||
@ -1195,7 +1195,7 @@ class TestWorkflowService:
|
||||
|
||||
assert len(result) > 0
|
||||
|
||||
def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service):
|
||||
def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service: WorkflowService):
|
||||
injected_config = HttpRequestNodeConfig(
|
||||
max_connect_timeout=15,
|
||||
max_read_timeout=25,
|
||||
@ -1234,7 +1234,7 @@ class TestWorkflowService:
|
||||
assert passed_http_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
|
||||
mock_llm_node_class.get_default_config.assert_called_once_with(filters=None)
|
||||
|
||||
def test_get_default_block_config_for_node_type(self, workflow_service):
|
||||
def test_get_default_block_config_for_node_type(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test get_default_block_config returns config for specific node type.
|
||||
|
||||
@ -1258,7 +1258,7 @@ class TestWorkflowService:
|
||||
assert result == mock_config
|
||||
mock_node_class.get_default_config.assert_called_once()
|
||||
|
||||
def test_get_default_block_config_invalid_node_type(self, workflow_service):
|
||||
def test_get_default_block_config_invalid_node_type(self, workflow_service: WorkflowService):
|
||||
"""Test get_default_block_config returns empty dict for invalid node type."""
|
||||
with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping:
|
||||
mock_mapping.return_value = {}
|
||||
@ -1268,7 +1268,7 @@ class TestWorkflowService:
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_get_default_block_config_http_request_injects_default_config(self, workflow_service):
|
||||
def test_get_default_block_config_http_request_injects_default_config(self, workflow_service: WorkflowService):
|
||||
injected_config = HttpRequestNodeConfig(
|
||||
max_connect_timeout=11,
|
||||
max_read_timeout=22,
|
||||
@ -1299,7 +1299,7 @@ class TestWorkflowService:
|
||||
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
|
||||
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
|
||||
|
||||
def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service):
|
||||
def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service: WorkflowService):
|
||||
provided_config = HttpRequestNodeConfig(
|
||||
max_connect_timeout=13,
|
||||
max_read_timeout=23,
|
||||
@ -1330,7 +1330,9 @@ class TestWorkflowService:
|
||||
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
|
||||
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is provided_config
|
||||
|
||||
def test_get_default_block_config_http_request_malformed_config_raises_type_error(self, workflow_service):
|
||||
def test_get_default_block_config_http_request_malformed_config_raises_type_error(
|
||||
self, workflow_service: WorkflowService
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"services.workflow_service.get_node_type_classes_mapping",
|
||||
@ -1347,7 +1349,7 @@ class TestWorkflowService:
|
||||
# ==================== Workflow Conversion Tests ====================
|
||||
# These tests verify converting basic apps to workflow apps
|
||||
|
||||
def test_convert_to_workflow_from_chat_app(self, workflow_service):
|
||||
def test_convert_to_workflow_from_chat_app(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test convert_to_workflow converts chat app to workflow.
|
||||
|
||||
@ -1374,7 +1376,7 @@ class TestWorkflowService:
|
||||
assert result == mock_new_app
|
||||
mock_converter.convert_to_workflow.assert_called_once()
|
||||
|
||||
def test_convert_to_workflow_from_completion_app(self, workflow_service):
|
||||
def test_convert_to_workflow_from_completion_app(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test convert_to_workflow converts completion app to workflow.
|
||||
|
||||
@ -1395,7 +1397,7 @@ class TestWorkflowService:
|
||||
|
||||
assert result == mock_new_app
|
||||
|
||||
def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service):
|
||||
def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service: WorkflowService):
|
||||
"""
|
||||
Test convert_to_workflow raises error for invalid app mode.
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user