chore: DI current_user && use inspect (#37084)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2026-06-09 14:06:28 +09:00 committed by GitHub
parent bbdf3d7634
commit d11e4eeaf7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 576 additions and 757 deletions

View File

@ -21,7 +21,12 @@ from controllers.common.schema import (
from controllers.console import console_ns
from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
with_current_user,
)
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -54,7 +59,7 @@ from libs import helper
from libs.datetime_utils import naive_utc_now
from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value
from libs.login import current_account_with_tenant, login_required
from models import App
from models import Account, App
from models.model import AppMode
from models.workflow import Workflow
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
@ -401,13 +406,12 @@ class DraftWorkflowApi(Resource):
)
@console_ns.response(400, "Invalid workflow configuration")
@console_ns.response(403, "Permission denied")
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Sync draft workflow
"""
current_user, _ = current_account_with_tenant()
content_type = request.headers.get("Content-Type", "")
if "application/json" in content_type:
@ -468,13 +472,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
@ -514,12 +517,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
@ -552,12 +555,12 @@ class WorkflowDraftRunIterationNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Run draft workflow iteration node
"""
current_user, _ = current_account_with_tenant()
args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
@ -590,12 +593,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
@ -628,12 +631,12 @@ class WorkflowDraftRunLoopNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Run draft workflow loop node
"""
current_user, _ = current_account_with_tenant()
args = LoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
@ -695,12 +698,12 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
@ -724,12 +727,12 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
result = workflow_service.submit_human_input_form_preview(
@ -753,12 +756,12 @@ class WorkflowDraftHumanInputFormPreviewApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
@ -782,12 +785,12 @@ class WorkflowDraftHumanInputFormRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
result = workflow_service.submit_human_input_form_preview(
@ -811,12 +814,12 @@ class WorkflowDraftHumanInputDeliveryTestApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Test human input delivery
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
workflow_service.test_human_input_delivery(
@ -841,12 +844,12 @@ class DraftWorkflowRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Run draft workflow
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
external_trace_id = get_external_trace_id(request)
@ -911,12 +914,12 @@ class DraftWorkflowNodeRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Run draft workflow node
"""
current_user, _ = current_account_with_tenant()
args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {})
args = args_model.model_dump(exclude_none=True)
@ -981,12 +984,12 @@ class PublishedWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Publish workflow
"""
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
@ -1083,14 +1086,14 @@ class ConvertToWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Convert basic mode of chatbot app to workflow mode
Convert expert mode of chatbot app to workflow mode
Convert Completion App to Workflow App
"""
current_user, _ = current_account_with_tenant()
payload = console_ns.payload or {}
args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True)
@ -1122,9 +1125,9 @@ class WorkflowFeaturesApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
def post(self, current_user: Account, app_model: App):
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
features = args.features
@ -1150,12 +1153,12 @@ class PublishedAllWorkflowApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def get(self, app_model: App):
def get(self, current_user: Account, app_model: App):
"""
Get published workflows
"""
current_user, _ = current_account_with_tenant()
args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True))
page = args.page
@ -1199,9 +1202,9 @@ class DraftWorkflowRestoreApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, workflow_id: str):
current_user, _ = current_account_with_tenant()
def post(self, current_user: Account, app_model: App, workflow_id: str):
workflow_service = WorkflowService()
try:
@ -1237,12 +1240,12 @@ class WorkflowByIdApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def patch(self, app_model: App, workflow_id: str):
def patch(self, current_user: Account, app_model: App, workflow_id: str):
"""
Update workflow attributes
"""
current_user, _ = current_account_with_tenant()
args = WorkflowUpdatePayload.model_validate(console_ns.payload or {})
# Prepare update data
@ -1355,12 +1358,12 @@ class DraftWorkflowTriggerRunApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Poll for trigger events and execute full workflow when event arrives
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {})
node_id = args.node_id
workflow_service = WorkflowService()
@ -1419,12 +1422,12 @@ class DraftWorkflowTriggerNodeApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App, node_id: str):
def post(self, current_user: Account, app_model: App, node_id: str):
"""
Poll for trigger events and execute single node when event arrives
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
draft_workflow = workflow_service.get_draft_workflow(app_model)
@ -1499,12 +1502,12 @@ class DraftWorkflowTriggerRunAllApi(Resource):
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@with_current_user
@edit_permission_required
def post(self, app_model: App):
def post(self, current_user: Account, app_model: App):
"""
Full workflow debug when the start node is a trigger
"""
current_user, _ = current_account_with_tenant()
args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {})
node_ids = args.node_ids

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import uuid
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
@ -68,15 +69,6 @@ from tests.test_containers_integration_tests.controllers.console.helpers import
)
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _make_account() -> Account:
account = Account(
name="tester",
@ -108,7 +100,7 @@ class TestCompletionEndpoints:
def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
completion_module.AppGenerateService,
@ -125,13 +117,13 @@ class TestCompletionEndpoints:
"/",
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
resp = method(_make_account(), app_model=MagicMock(id="app-1"))
resp = method(api, _make_account(), app_model=MagicMock(id="app-1"))
assert resp == {"result": {"text": "ok"}}
def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
completion_module.AppGenerateService,
@ -146,11 +138,11 @@ class TestCompletionEndpoints:
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(NotFound):
method(_make_account(), app_model=MagicMock(id="app-1"))
method(api, _make_account(), app_model=MagicMock(id="app-1"))
def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
completion_module.AppGenerateService,
@ -163,11 +155,11 @@ class TestCompletionEndpoints:
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderNotInitializeError):
method(_make_account(), app_model=MagicMock(id="app-1"))
method(api, _make_account(), app_model=MagicMock(id="app-1"))
def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = completion_module.CompletionMessageApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
completion_module.AppGenerateService,
@ -180,7 +172,7 @@ class TestCompletionEndpoints:
json={"inputs": {}, "model_config": {}, "query": "hi"},
):
with pytest.raises(completion_module.ProviderQuotaExceededError):
method(_make_account(), app_model=MagicMock(id="app-1"))
method(api, _make_account(), app_model=MagicMock(id="app-1"))
class TestAppEndpoints:
@ -190,7 +182,7 @@ class TestAppEndpoints:
def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = app_module.AppApi()
method = _unwrap(api.put)
method = unwrap(api.put)
payload = {
"name": "Updated App",
"description": "Updated description",
@ -209,7 +201,7 @@ class TestAppEndpoints:
app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload),
patch.object(type(console_ns), "payload", payload),
):
response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI))
response = method(api, app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI))
assert response == {"id": "app-1"}
assert app_service.update_app.call_args.args[1]["icon_type"] is None
@ -228,7 +220,7 @@ class TestAppEndpoints:
def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = app_module.AppIconApi()
method = _unwrap(api.post)
method = unwrap(api.post)
payload = {
"icon": "https://example.com/icon.png",
"icon_type": "image",
@ -246,7 +238,7 @@ class TestAppEndpoints:
app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload),
patch.object(type(console_ns), "payload", payload),
):
response = method(app_model=SimpleNamespace())
response = method(api, app_model=SimpleNamespace())
assert response == {"id": "app-1"}
assert app_service.update_app_icon.call_args.args[1:] == (
@ -300,7 +292,7 @@ class TestOpsTraceEndpoints:
def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(
ops_trace_module.OpsService,
@ -309,13 +301,13 @@ class TestOpsTraceEndpoints:
)
with app.test_request_context("/?tracing_provider=langfuse"):
result = method(app_model=MagicMock(id="app-1"))
result = method(api, app_model=MagicMock(id="app-1"))
assert result == {"has_not_configured": True}
def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
ops_trace_module.OpsService,
@ -328,11 +320,11 @@ class TestOpsTraceEndpoints:
json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}},
):
with pytest.raises(BadRequest):
method(app_model=MagicMock(id="app-1"))
method(api, app_model=MagicMock(id="app-1"))
def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = ops_trace_module.TraceAppConfigApi()
method = _unwrap(api.delete)
method = unwrap(api.delete)
monkeypatch.setattr(
ops_trace_module.OpsService,
@ -342,7 +334,7 @@ class TestOpsTraceEndpoints:
with app.test_request_context("/?tracing_provider=langfuse"):
with pytest.raises(BadRequest):
method(app_model=MagicMock(id="app-1"))
method(api, app_model=MagicMock(id="app-1"))
class TestSiteEndpoints:
@ -360,7 +352,7 @@ class TestSiteEndpoints:
def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = site_module.AppSite()
method = _unwrap(api.post)
method = unwrap(api.post)
site = MagicMock()
site.app_id = "app-1"
@ -386,14 +378,14 @@ class TestSiteEndpoints:
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/", json={"title": "My Site"}):
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
result = method(api, SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
assert isinstance(result, dict)
assert result["title"] == "My Site"
def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = site_module.AppSiteAccessTokenReset()
method = _unwrap(api.post)
method = unwrap(api.post)
site = MagicMock()
site.app_id = "app-1"
@ -420,7 +412,7 @@ class TestSiteEndpoints:
monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now")
with app.test_request_context("/"):
result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
result = method(api, SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1"))
assert isinstance(result, dict)
assert result["access_token"] == "code"
@ -451,7 +443,7 @@ class TestWorkflowAppLogEndpoints:
def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_app_log_module.WorkflowAppLogApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock()))
@ -481,7 +473,7 @@ class TestWorkflowAppLogEndpoints:
)
with app.test_request_context("/?page=1&limit=20"):
result = method(app_model=SimpleNamespace(id="app-1"))
result = method(api, app_model=SimpleNamespace(id="app-1"))
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
@ -497,7 +489,7 @@ class TestWorkflowDraftVariableEndpoints:
def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_draft_variable_module.WorkflowVariableCollectionApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock()))
@ -532,7 +524,7 @@ class TestWorkflowDraftVariableEndpoints:
monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService)
with app.test_request_context("/?page=1&limit=20"):
result = method(_make_account(), app_model=SimpleNamespace(id="app-1"))
result = method(api, _make_account(), app_model=SimpleNamespace(id="app-1"))
assert result == {"items": [], "total": 0}
@ -565,10 +557,12 @@ class TestWorkflowStatisticEndpoints:
)
api = workflow_statistic_module.WorkflowDailyRunsStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
with app.test_request_context("/"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
response = method(
api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")
)
assert response.get_json() == {"data": [{"date": "2024-01-01"}]}
@ -588,10 +582,12 @@ class TestWorkflowStatisticEndpoints:
)
api = workflow_statistic_module.WorkflowDailyTerminalsStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
with app.test_request_context("/"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1"))
response = method(
api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")
)
assert response.get_json() == {"data": [{"date": "2024-01-02"}]}
@ -610,7 +606,7 @@ class TestWorkflowTriggerEndpoints:
def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch):
api = workflow_trigger_module.WebhookTriggerApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock()))
@ -635,7 +631,7 @@ class TestWorkflowTriggerEndpoints:
monkeypatch.setattr(workflow_trigger_module, "sessionmaker", DummySessionMaker)
with app.test_request_context("/?node_id=node-1"):
result = method(app_model=SimpleNamespace(id="app-1"))
result = method(api, app_model=SimpleNamespace(id="app-1"))
assert isinstance(result, dict)
assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys())

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock
@ -12,15 +13,6 @@ from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
@ -42,7 +34,7 @@ class TestAppImportApi:
def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
@ -52,14 +44,14 @@ class TestAppImportApi:
)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
assert status == 400
assert response["status"] == ImportStatus.FAILED
def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
@ -69,14 +61,14 @@ class TestAppImportApi:
)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
assert status == 202
assert response["status"] == ImportStatus.PENDING
def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=True)
monkeypatch.setattr(
@ -88,7 +80,7 @@ class TestAppImportApi:
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
update_access.assert_called_once_with("app-123", "private")
assert status == 200
@ -96,7 +88,7 @@ class TestAppImportApi:
def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
@ -111,7 +103,7 @@ class TestAppImportApi:
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
fake_session.commit.assert_called_once_with()
fake_session.rollback.assert_not_called()
@ -120,7 +112,7 @@ class TestAppImportApi:
def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=False)
monkeypatch.setattr(
@ -135,7 +127,7 @@ class TestAppImportApi:
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
fake_session.rollback.assert_called_once_with()
fake_session.commit.assert_not_called()
@ -150,7 +142,7 @@ class TestAppImportConfirmApi:
def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportConfirmApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
app_import_module.AppDslService,
@ -159,7 +151,7 @@ class TestAppImportConfirmApi:
)
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
response, status = method(api, SimpleNamespace(id="u1"), import_id="import-1")
assert status == 400
assert response["status"] == ImportStatus.FAILED
@ -172,7 +164,7 @@ class TestAppImportCheckDependenciesApi:
def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = app_import_module.AppImportCheckDependenciesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(
app_import_module.AppDslService,
@ -181,7 +173,7 @@ class TestAppImportCheckDependenciesApi:
)
with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"):
response, status = method(app_model=SimpleNamespace(id="app-1"))
response, status = method(api, app_model=SimpleNamespace(id="app-1"))
assert status == 200
assert response["leaked_dependencies"] == []

View File

@ -239,13 +239,7 @@ class TestTagUnbindingPayload:
# Helpers
# ---------------------------------------------------------------------------
def _unwrap(method):
"""Walk ``__wrapped__`` chain to get the original function."""
fn = method
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return fn
from inspect import unwrap
@pytest.fixture
@ -499,7 +493,7 @@ class TestDatasetListApiPost:
json={"name": "New Dataset"},
):
api = DatasetListApi()
response, status = _unwrap(api.post)(api, tenant_id=mock_tenant.id)
response, status = unwrap(api.post)(api, tenant_id=mock_tenant.id)
assert status == 200
assert_dataset_detail_shape(response)
@ -527,7 +521,7 @@ class TestDatasetListApiPost:
):
api = DatasetListApi()
with pytest.raises(DatasetNameDuplicateError):
_unwrap(api.post)(api, tenant_id=mock_tenant.id)
unwrap(api.post)(api, tenant_id=mock_tenant.id)
# ---------------------------------------------------------------------------
@ -720,7 +714,7 @@ class TestDatasetApiPatch:
json=payload,
):
api = DatasetApi()
response, status = _unwrap(api.patch)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
response, status = unwrap(api.patch)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
assert status == 200
assert_dataset_detail_shape(response, with_partial_members=True)
@ -760,7 +754,7 @@ class TestDatasetApiDelete:
method="DELETE",
):
api = DatasetApi()
result = _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
result = unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
assert result == ("", 204)
@ -783,7 +777,7 @@ class TestDatasetApiDelete:
):
api = DatasetApi()
with pytest.raises(NotFound):
_unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
@patch("controllers.service_api.dataset.dataset.current_user")
@patch("controllers.service_api.dataset.dataset.DatasetService")
@ -804,7 +798,7 @@ class TestDatasetApiDelete:
):
api = DatasetApi()
with pytest.raises(DatasetInUseError):
_unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id)
# ---------------------------------------------------------------------------

View File

@ -19,11 +19,7 @@ def app(flask_app_with_containers) -> Flask:
return flask_app_with_containers
def _unwrap(method):
fn = method
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return fn
from inspect import unwrap
def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant:
@ -76,7 +72,7 @@ class TestAppSiteApi:
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
api = AppSiteApi()
response = _unwrap(api.get)(api, app_model=app_model)
response = unwrap(api.get)(api, app_model=app_model)
assert response["title"] == "Service API Site"
assert response["icon"] == "robot"
@ -89,7 +85,7 @@ class TestAppSiteApi:
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
_unwrap(api.get)(api, app_model=app_model)
unwrap(api.get)(api, app_model=app_model)
def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None:
tenant = _create_tenant(db_session_with_containers)
@ -107,4 +103,4 @@ class TestAppSiteApi:
with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}):
api = AppSiteApi()
with pytest.raises(Forbidden):
_unwrap(api.get)(api, app_model=app_model)
unwrap(api.get)(api, app_model=app_model)

View File

@ -62,7 +62,7 @@ def _create_app_with_site(session: Session) -> tuple[App, Account]:
tenant_id=tenant.id,
name="Test App",
description="",
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
icon_type="emoji",
icon="app",
icon_background="#ffffff",

View File

@ -205,7 +205,7 @@ class TestHumanInputResumeNodeExecutionIntegration:
tenant_id=tenant.id,
name="Test App",
description="",
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
icon_type=IconType.EMOJI.value,
icon="rocket",
icon_background="#4ECDC4",

View File

@ -75,7 +75,7 @@ def _app_stub(**overrides: Any) -> App:
defaults = {
"id": str(uuid4()),
"tenant_id": _DEFAULT_TENANT_ID,
"mode": AppMode.WORKFLOW.value,
"mode": AppMode.WORKFLOW,
"name": "n",
"description": "d",
"icon_type": IconType.EMOJI,
@ -528,7 +528,7 @@ class TestAppDslService:
created_app = SimpleNamespace(
id=str(uuid4()),
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
tenant_id=_DEFAULT_TENANT_ID,
)
monkeypatch.setattr(
@ -707,7 +707,7 @@ class TestAppDslService:
)
app = _app_stub(
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
name="old",
description="old-desc",
icon_type=IconType.EMOJI,
@ -721,7 +721,7 @@ class TestAppDslService:
app=app,
data={
"app": {
"mode": AppMode.WORKFLOW.value,
"mode": AppMode.WORKFLOW,
"name": "yaml-name",
"icon_type": IconType.IMAGE,
"icon": "X",
@ -747,7 +747,7 @@ class TestAppDslService:
with pytest.raises(ValueError, match="Current tenant is not set"):
service._create_or_update_app(
app=None,
data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}},
data={"app": {"mode": AppMode.WORKFLOW, "name": "n"}},
account=account,
)
@ -772,7 +772,7 @@ class TestAppDslService:
)
]
data = {
"app": {"mode": AppMode.WORKFLOW.value, "name": "n"},
"app": {"mode": AppMode.WORKFLOW, "name": "n"},
"workflow": {
"graph": {"nodes": []},
"features": {},
@ -792,8 +792,8 @@ class TestAppDslService:
service = AppDslService(db_session_with_containers)
with pytest.raises(ValueError, match="Missing workflow data"):
service._create_or_update_app(
app=_app_stub(mode=AppMode.WORKFLOW.value),
data={"app": {"mode": AppMode.WORKFLOW.value}},
app=_app_stub(mode=AppMode.WORKFLOW),
data={"app": {"mode": AppMode.WORKFLOW}},
account=_account_mock(),
)
@ -852,7 +852,7 @@ class TestAppDslService:
)
workflow_app = _app_stub(
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
icon_type="emoji",
)
AppDslService.export_dsl(workflow_app)
@ -874,7 +874,7 @@ class TestAppDslService:
)
emoji_app = _app_stub(
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
name="Emoji App",
icon="🎨",
icon_type=IconType.EMOJI,
@ -889,7 +889,7 @@ class TestAppDslService:
assert data["app"]["icon_background"] == "#FF5733"
image_app = _app_stub(
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
name="Image App",
icon="https://example.com/icon.png",
icon_type=IconType.IMAGE,

View File

@ -50,7 +50,7 @@ def _create_app_with_draft_workflow(
tenant_id=tenant.id,
name="Test App",
description="",
mode=AppMode.WORKFLOW.value,
mode=AppMode.WORKFLOW,
icon_type="emoji",
icon="app",
icon_background="#ffffff",

View File

@ -1,3 +1,4 @@
from inspect import unwrap
from types import SimpleNamespace
from typing import Protocol, cast
@ -26,12 +27,6 @@ from controllers.console.agent.roster import (
from services.entities.agent_entities import ComposerSaveStrategy, ComposerVariant
def _unwrap(method):
while hasattr(method, "__wrapped__"):
method = method.__wrapped__
return method
def _agent_response(agent_id: str = "agent-1") -> dict:
return {
"id": agent_id,
@ -135,7 +130,7 @@ def test_roster_list_get_parses_query_and_calls_service(app: Flask, monkeypatch:
monkeypatch.setattr(roster_controller.AgentRosterService, "list_roster_agents", list_roster_agents)
with app.test_request_context("/console/api/agents?page=2&limit=5&keyword=analyst"):
result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1")
result = unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1")
assert result["page"] == 2
assert captured == {"tenant_id": "tenant-1", "page": 2, "limit": 5, "keyword": "analyst"}
@ -157,7 +152,7 @@ def test_roster_list_post_creates_agent_and_returns_detail(
)
with app.test_request_context(json={"name": "Analyst", "agent_soul": {"prompt": {"system_prompt": "x"}}}):
result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id)
result, status = unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id)
assert status == 201
assert result["id"] == "agent-1"
@ -174,7 +169,7 @@ def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.Monkey
monkeypatch.setattr(roster_controller.AgentRosterService, "list_invite_options", list_invite_options)
with app.test_request_context("/console/api/agents/invite-options?page=1&limit=10&app_id=app-1"):
result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1")
result = unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1")
assert result == {"data": [], "page": 1, "limit": 10, "total": 0, "has_more": False}
assert captured == {"tenant_id": "tenant-1", "page": 1, "limit": 10, "keyword": None, "app_id": "app-1"}
@ -232,19 +227,19 @@ def test_roster_detail_patch_delete_and_versions_call_services(
},
)
assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id
assert unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id
with app.test_request_context(json={"description": "updated"}):
assert (
_unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"]
unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"]
== "updated"
)
assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204)
assert unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204)
assert archived["account_id"] == "account-1"
assert (
_unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"]
== "version-1"
)
version_detail = _unwrap(AgentRosterVersionDetailApi.get)(
version_detail = unwrap(AgentRosterVersionDetailApi.get)(
AgentRosterVersionDetailApi(), "tenant-1", agent_id, version_id
)
assert version_detail["id"] == version_id
@ -286,27 +281,27 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save(
},
)
workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1")
workflow_state = unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1")
assert workflow_state["node_id"] == "node-1"
with app.test_request_context(json=payload):
saved_state = _unwrap(WorkflowAgentComposerApi.put)(
saved_state = unwrap(WorkflowAgentComposerApi.put)(
WorkflowAgentComposerApi(), "tenant-1", account_id, app_model, "node-1"
)
assert saved_state["save_options"] == ["node_job_only"]
assert _unwrap(WorkflowAgentComposerValidateApi.post)(
assert unwrap(WorkflowAgentComposerValidateApi.post)(
WorkflowAgentComposerValidateApi(), app_model, "node-1"
) == {"result": "success", "errors": []}
assert (
_unwrap(WorkflowAgentComposerCandidatesApi.get)(WorkflowAgentComposerCandidatesApi(), app_model, "node-1")[
unwrap(WorkflowAgentComposerCandidatesApi.get)(WorkflowAgentComposerCandidatesApi(), app_model, "node-1")[
"variant"
]
== "workflow"
)
with app.test_request_context(json=payload):
assert _unwrap(WorkflowAgentComposerImpactApi.post)(
assert unwrap(WorkflowAgentComposerImpactApi.post)(
WorkflowAgentComposerImpactApi(), "tenant-1", app_model, "node-1"
) == {"current_snapshot_id": "version-1", "workflow_node_count": 1, "bindings": []}
assert _unwrap(WorkflowAgentComposerSaveToRosterApi.post)(
assert unwrap(WorkflowAgentComposerSaveToRosterApi.post)(
WorkflowAgentComposerSaveToRosterApi(), "tenant-1", account_id, app_model, "node-1"
)["save_options"] == ["node_job_only"]
@ -315,7 +310,7 @@ def test_workflow_impact_returns_empty_without_version(app: Flask) -> None:
payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value}
with app.test_request_context(json=payload):
result = _unwrap(WorkflowAgentComposerImpactApi.post)(
result = unwrap(WorkflowAgentComposerImpactApi.post)(
WorkflowAgentComposerImpactApi(), "tenant-1", SimpleNamespace(id="app-1"), "node-1"
)
@ -348,15 +343,15 @@ def test_agent_app_composer_get_put_validate_and_candidates(
lambda **kwargs: _candidates_response("agent_app"),
)
assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app"
assert unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app"
with app.test_request_context(json=payload):
assert (
_unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"]
unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"]
== "agent_app"
)
assert _unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == {
assert unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == {
"result": "success",
"errors": [],
}
agent_app_candidates = _unwrap(AgentAppComposerCandidatesApi.get)(AgentAppComposerCandidatesApi(), app_model)
agent_app_candidates = unwrap(AgentAppComposerCandidatesApi.get)(AgentAppComposerCandidatesApi(), app_model)
assert agent_app_candidates["variant"] == "agent_app"

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock
@ -12,15 +13,6 @@ from controllers.console.app import app_import as app_import_module
from services.app_dsl_service import ImportStatus
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _Result:
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
self.status = status
@ -52,7 +44,7 @@ class TestAppImportApi:
def test_import_post_returns_failed_status_and_rolls_back(
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=False)
session = _mock_session(monkeypatch)
@ -63,7 +55,7 @@ class TestAppImportApi:
)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
session.rollback.assert_called_once_with()
session.commit.assert_not_called()
@ -73,7 +65,7 @@ class TestAppImportApi:
def test_import_post_returns_pending_status_and_commits(
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=False)
session = _mock_session(monkeypatch)
@ -84,7 +76,7 @@ class TestAppImportApi:
)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
session.commit.assert_called_once_with()
session.rollback.assert_not_called()
@ -94,7 +86,7 @@ class TestAppImportApi:
def test_import_post_updates_webapp_auth_when_enabled(
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
method = _unwrap(api.post)
method = unwrap(api.post)
_install_features(monkeypatch, enabled=True)
session = _mock_session(monkeypatch)
@ -107,7 +99,7 @@ class TestAppImportApi:
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
response, status = method(SimpleNamespace(id="u1"))
response, status = method(api, SimpleNamespace(id="u1"))
session.commit.assert_called_once_with()
session.rollback.assert_not_called()
@ -124,7 +116,7 @@ class TestAppImportConfirmApi:
def test_import_confirm_returns_failed_status_and_rolls_back(
self, api, app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
method = _unwrap(api.post)
method = unwrap(api.post)
session = _mock_session(monkeypatch)
monkeypatch.setattr(
@ -134,7 +126,7 @@ class TestAppImportConfirmApi:
)
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
response, status = method(SimpleNamespace(id="u1"), import_id="import-1")
response, status = method(api, SimpleNamespace(id="u1"), import_id="import-1")
session.rollback.assert_called_once_with()
session.commit.assert_not_called()

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import io
from inspect import unwrap
from types import SimpleNamespace
import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import InternalServerError
@ -32,27 +34,18 @@ from services.errors.audio import (
)
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_console_audio_api_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
response = handler(app_model=app_model)
response = handler(api, app_model=app_model)
assert response == {"text": "ok"}
@ -71,33 +64,33 @@ def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None
(InvokeError("invoke"), CompletionRequestError),
],
)
def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
def test_console_audio_api_error_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(expected):
handler(app_model=app_model)
handler(api, app_model=app_model)
def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_console_audio_api_unhandled_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")))
api = ChatMessageAudioApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}):
with pytest.raises(InternalServerError):
handler(app_model=app_model)
handler(api, app_model=app_model)
def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_console_text_api_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = ChatMessageTextApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
@ -105,16 +98,16 @@ def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
method="POST",
json={"text": "hello", "voice": "v"},
):
response = handler(app_model=app_model)
response = handler(api, app_model=app_model)
assert response == {"audio": "ok"}
def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_console_text_api_error_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError()))
api = ChatMessageTextApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
with app.test_request_context(
@ -123,23 +116,23 @@ def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) ->
json={"text": "hello"},
):
with pytest.raises(ProviderQuotaExceededError):
handler(app_model=app_model)
handler(api, app_model=app_model)
def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_console_text_modes_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
api = TextModesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
response = handler(app_model=app_model)
response = handler(api, app_model=app_model)
assert response == ["voice-1"]
def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_console_text_modes_language_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
AudioService,
"transcript_tts_voices",
@ -147,17 +140,17 @@ def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch)
)
api = TextModesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(tenant_id="t1")
with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"):
with pytest.raises(AppUnavailableError):
handler(app_model=app_model)
handler(api, app_model=app_model)
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_audio_to_text_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
method = unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload)
@ -171,14 +164,14 @@ def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_audio_to_text_maps_audio_too_large(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
AudioService,
@ -196,12 +189,12 @@ def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
method(api, app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
@ -212,14 +205,14 @@ def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_voices_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
@ -230,14 +223,14 @@ def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> N
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_audio_to_text_with_invalid_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageAudioApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
@ -251,13 +244,13 @@ def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -
content_type="multipart/form-data",
):
# Should not raise, AudioService is mocked
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_with_language_param(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = ChatMessageTextApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
@ -268,13 +261,13 @@ def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch)
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_voices_with_language_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = TextModesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(
AudioService,
@ -288,5 +281,5 @@ def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.Monk
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert isinstance(response, list)

View File

@ -1,27 +1,20 @@
from __future__ import annotations
import io
from inspect import unwrap
from types import SimpleNamespace
import pytest
from flask import Flask
from controllers.console.app import audio as audio_module
from controllers.console.app.error import AudioTooLargeError
from services.errors.audio import AudioTooLargeServiceError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_audio_to_text_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
method = unwrap(api.post)
response_payload = {"text": "hello"}
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload)
@ -35,14 +28,14 @@ def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
data=data,
content_type="multipart/form-data",
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == response_payload
def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_audio_to_text_maps_audio_too_large(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(
audio_module.AudioService,
@ -60,12 +53,12 @@ def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch
content_type="multipart/form-data",
):
with pytest.raises(AudioTooLargeError):
method(app_model=app_model)
method(api, app_model=app_model)
def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
@ -76,14 +69,14 @@ def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
method="POST",
json={"text": "hello"},
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == {"audio": "ok"}
def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_voices_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"])
@ -94,14 +87,14 @@ def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> N
method="GET",
query_string={"language": "en-US"},
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == ["voice-1"]
def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_audio_to_text_with_invalid_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageAudioApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"})
@ -115,13 +108,13 @@ def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -
content_type="multipart/form-data",
):
# Should not raise, AudioService is mocked
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == {"text": "test"}
def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_with_language_param(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.ChatMessageTextApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"})
@ -132,13 +125,13 @@ def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch)
method="POST",
json={"text": "hello", "language": "en-US"},
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert response == {"audio": "test"}
def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_text_to_audio_voices_with_language_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = audio_module.TextModesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
monkeypatch.setattr(
audio_module.AudioService,
@ -152,5 +145,5 @@ def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.Monk
"/console/api/apps/app-1/text-to-audio/voices?language=en-US",
method="GET",
):
response = method(app_model=app_model)
response = method(api, app_model=app_model)
assert isinstance(response, list)

View File

@ -1,9 +1,11 @@
from __future__ import annotations
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from werkzeug.exceptions import BadRequest, NotFound
from controllers.console.app import conversation as conversation_module
@ -11,22 +13,13 @@ from models.model import AppMode
from services.errors.conversation import ConversationNotExistsError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _make_account():
return SimpleNamespace(timezone="UTC", id="u1")
def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_completion_conversation_list_returns_paginated_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
method = unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
@ -40,14 +33,14 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch:
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"):
response = method(account, app_model=SimpleNamespace(id="app-1"))
response = method(api, account, app_model=SimpleNamespace(id="app-1"))
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_completion_conversation_list_invalid_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationApi()
method = _unwrap(api.get)
method = unwrap(api.get)
account = _make_account()
monkeypatch.setattr(
@ -62,12 +55,12 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes
query_string={"start": "bad"},
):
with pytest.raises(BadRequest):
method(account, app_model=SimpleNamespace(id="app-1"))
method(api, account, app_model=SimpleNamespace(id="app-1"))
def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_chat_conversation_list_advanced_chat_calls_paginate(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.ChatConversationApi()
method = _unwrap(api.get)
method = unwrap(api.get)
account = _make_account()
monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None))
@ -81,7 +74,7 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p
monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result)
with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"):
response = method(account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
response = method(api, account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT))
assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}
@ -114,7 +107,7 @@ def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPat
def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_module.CompletionConversationDetailApi()
method = _unwrap(api.delete)
method = unwrap(api.delete)
monkeypatch.setattr(
conversation_module.ConversationService,
@ -123,4 +116,4 @@ def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.Monke
)
with pytest.raises(NotFound):
method(_make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1")
method(api, _make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1")

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from contextlib import nullcontext
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
import pytest
@ -12,18 +13,9 @@ from controllers.console.app import conversation_variables as conversation_varia
from graphon.variables.types import SegmentType
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_get_conversation_variables_returns_paginated_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
created_at = datetime(2026, 1, 1, tzinfo=UTC)
updated_at = datetime(2026, 1, 2, tzinfo=UTC)
@ -53,7 +45,7 @@ def test_get_conversation_variables_returns_paginated_response(app: Flask, monke
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(api, app_model=SimpleNamespace(id="app-1"))
assert response["page"] == 1
assert response["limit"] == 100
@ -68,7 +60,7 @@ def test_get_conversation_variables_normalizes_value_type_and_value(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
row = SimpleNamespace(
created_at=None,
@ -96,7 +88,7 @@ def test_get_conversation_variables_normalizes_value_type_and_value(
method="GET",
query_string={"conversation_id": "conv-1"},
):
response = method(app_model=SimpleNamespace(id="app-1"))
response = method(api, app_model=SimpleNamespace(id="app-1"))
assert response["data"][0]["value_type"] == "number"
assert response["data"][0]["value"] == "42"
@ -104,8 +96,8 @@ def test_get_conversation_variables_normalizes_value_type_and_value(
def test_get_conversation_variables_requires_conversation_id(app) -> None:
api = conversation_variables_module.ConversationVariablesApi()
method = _unwrap(api.get)
method = unwrap(api.get)
with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"):
with pytest.raises(ValidationError):
method(app_model=SimpleNamespace(id="app-1"))
method(api, app_model=SimpleNamespace(id="app-1"))

View File

@ -1,24 +1,17 @@
from __future__ import annotations
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from flask import Flask
from controllers.console.app import generator as generator_module
from controllers.console.app.error import ProviderNotInitializeError
from core.errors.error import ProviderTokenNotInitError
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def _model_config_payload():
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
@ -38,9 +31,9 @@ def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow):
return service
def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_rule_generate_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []})
@ -49,14 +42,14 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
method="POST",
json={"instruction": "do it", "model_config": _model_config_payload()},
):
response = method("t1")
response = method(api, "t1")
assert response == {"rules": []}
def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_rule_code_generate_maps_token_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.RuleCodeGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
def _raise(*_args, **_kwargs):
raise ProviderTokenNotInitError("missing token")
@ -69,12 +62,12 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc
json={"instruction": "do it", "model_config": _model_config_payload()},
):
with pytest.raises(ProviderNotInitializeError):
method("t1")
method(api, "t1")
def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
session = MagicMock()
session.get.return_value = None
@ -89,16 +82,16 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch
"model_config": _model_config_payload(),
},
):
response, status = method(session, "t1")
response, status = method(api, session, "t1")
assert status == 400
assert response["error"] == "app app-1 not found"
session.get.assert_called_once_with(generator_module.App, "app-1")
def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
@ -114,15 +107,15 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey
"model_config": _model_config_payload(),
},
):
response, status = method(session, "t1")
response, status = method(api, session, "t1")
assert status == 400
assert response["error"] == "workflow app-1 not found"
def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_instruction_generate_node_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
@ -140,15 +133,15 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch)
"model_config": _model_config_payload(),
},
):
response, status = method(session, "t1")
response, status = method(api, session, "t1")
assert status == 400
assert response["error"] == "node node-1 not found"
def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_instruction_generate_code_node(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
app_model = SimpleNamespace(id="app-1")
session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model)
@ -173,16 +166,16 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) ->
"model_config": _model_config_payload(),
},
):
response = method(session, "t1")
response = method(api, session, "t1")
assert response == {"code": "x"}
assert workflow_service.app_model is app_model
assert workflow_service.session is session
def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_instruction_generate_legacy_modify(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
session = SimpleNamespace()
monkeypatch.setattr(
@ -202,14 +195,14 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch
"model_config": _model_config_payload(),
},
):
response = method(session, "t1")
response = method(api, session, "t1")
assert response == {"instruction": "ok"}
def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_instruction_generate_incompatible_params(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.InstructionGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
session = SimpleNamespace()
with app.test_request_context(
@ -223,29 +216,29 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke
"model_config": _model_config_payload(),
},
):
response, status = method(session, "t1")
response, status = method(api, session, "t1")
assert status == 400
assert response["error"] == "incompatible parameters"
def test_instruction_template_prompt(app) -> None:
def test_instruction_template_prompt(app: Flask) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
method="POST",
json={"type": "prompt"},
):
response = method()
response = method(api)
assert "data" in response
def test_instruction_template_invalid_type(app) -> None:
def test_instruction_template_invalid_type(app: Flask) -> None:
api = generator_module.InstructionGenerationTemplateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
with app.test_request_context(
"/console/api/instruction-generate/template",
@ -253,7 +246,7 @@ def test_instruction_template_invalid_type(app) -> None:
json={"type": "unknown"},
):
with pytest.raises(ValueError):
method()
method(api)
# ─ /workflow-generate ─────────────────────────────────────────────────────────
@ -281,9 +274,9 @@ def _stub_workflow_service(monkeypatch: pytest.MonkeyPatch, returns=None, raises
monkeypatch.setattr(generator_module.WorkflowGeneratorService, "generate_workflow_graph", _call)
def test_workflow_generate_returns_service_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_returns_service_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
expected = {
"graph": {"nodes": [{"id": "node-1"}], "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 0.7}},
@ -297,16 +290,16 @@ def test_workflow_generate_returns_service_result(app, monkeypatch: pytest.Monke
method="POST",
json=_workflow_generate_payload(),
):
response = method("t1")
response = method(api, "t1")
assert response == expected
def test_workflow_generate_maps_provider_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_maps_provider_token_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""ProviderTokenNotInitError → ProviderNotInitializeError so the frontend
can render the same "provider missing" UX as /rule-generate."""
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_stub_workflow_service(monkeypatch, raises=ProviderTokenNotInitError("missing token"))
@ -316,15 +309,15 @@ def test_workflow_generate_maps_provider_token_error(app, monkeypatch: pytest.Mo
json=_workflow_generate_payload(),
):
with pytest.raises(ProviderNotInitializeError):
method("t1")
method(api, "t1")
def test_workflow_generate_maps_quota_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_maps_quota_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
from controllers.console.app.error import ProviderQuotaExceededError
from core.errors.error import QuotaExceededError
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_stub_workflow_service(monkeypatch, raises=QuotaExceededError())
@ -334,15 +327,15 @@ def test_workflow_generate_maps_quota_error(app, monkeypatch: pytest.MonkeyPatch
json=_workflow_generate_payload(),
):
with pytest.raises(ProviderQuotaExceededError):
method("t1")
method(api, "t1")
def test_workflow_generate_maps_model_not_support_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_maps_model_not_support_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
from controllers.console.app.error import ProviderModelCurrentlyNotSupportError
from core.errors.error import ModelCurrentlyNotSupportError
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_stub_workflow_service(monkeypatch, raises=ModelCurrentlyNotSupportError("not supported"))
@ -352,15 +345,15 @@ def test_workflow_generate_maps_model_not_support_error(app, monkeypatch: pytest
json=_workflow_generate_payload(),
):
with pytest.raises(ProviderModelCurrentlyNotSupportError):
method("t1")
method(api, "t1")
def test_workflow_generate_maps_invoke_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_maps_invoke_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
from controllers.console.app.error import CompletionRequestError
from graphon.model_runtime.errors.invoke import InvokeError
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
_stub_workflow_service(monkeypatch, raises=InvokeError("LLM unreachable"))
@ -370,13 +363,13 @@ def test_workflow_generate_maps_invoke_error(app, monkeypatch: pytest.MonkeyPatc
json=_workflow_generate_payload(),
):
with pytest.raises(CompletionRequestError):
method("t1")
method(api, "t1")
def test_workflow_generate_accepts_advanced_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_accepts_advanced_chat_mode(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""The payload Literal must accept advanced-chat as well as workflow."""
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
captured: dict = {}
@ -397,17 +390,17 @@ def test_workflow_generate_accepts_advanced_chat_mode(app, monkeypatch: pytest.M
method="POST",
json=payload,
):
method("t1")
method(api, "t1")
assert captured["mode"] == "advanced-chat"
assert captured["instruction"] == "Summarize a URL"
assert captured["ideal_output"] == "A 3-sentence summary."
def test_workflow_generate_forwards_current_graph_for_refine(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_forwards_current_graph_for_refine(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""cmd+k `/refine`: the optional current_graph field reaches the service."""
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
captured: dict = {}
@ -429,15 +422,15 @@ def test_workflow_generate_forwards_current_graph_for_refine(app, monkeypatch: p
method="POST",
json=payload,
):
method("t1")
method(api, "t1")
assert captured["current_graph"] == graph
def test_workflow_generate_current_graph_defaults_to_none(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_generate_current_graph_defaults_to_none(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
"""Omitting current_graph (the `/create` path) forwards None to the service."""
api = generator_module.WorkflowGenerateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
captured: dict = {}
@ -456,6 +449,6 @@ def test_workflow_generate_current_graph_defaults_to_none(app, monkeypatch: pyte
method="POST",
json=_workflow_generate_payload(),
):
method("t1")
method(api, "t1")
assert captured["current_graph"] is None

View File

@ -7,15 +7,6 @@ import pytest
from controllers.console.app import message as message_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test valid ChatMessagesQuery with all fields."""
query = message_module.ChatMessagesQuery(

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import json
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock
@ -11,18 +12,9 @@ from controllers.console.app import model_config as model_config_module
from models.model import AppMode, AppModelConfig
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
method = unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
@ -50,7 +42,7 @@ def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method("t1", "u1", app_model=app_model)
response = method(api, "t1", "u1", app_model=app_model)
session.add.assert_called_once()
session.flush.assert_called_once()
@ -62,7 +54,7 @@ def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.
def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = model_config_module.ModelConfigResource()
method = _unwrap(api.post)
method = unwrap(api.post)
app_model = SimpleNamespace(
id="app-1",
@ -137,7 +129,7 @@ def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.Mon
monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock)
with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}):
response = method("t1", "u1", app_model=app_model)
response = method(api, "t1", "u1", app_model=app_model)
stored_config = session.add.call_args[0][0]
stored_agent_mode = json.loads(stored_config.agent_mode)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from decimal import Decimal
from inspect import unwrap
from types import SimpleNamespace
import pytest
@ -9,15 +10,6 @@ from werkzeug.exceptions import BadRequest
from controllers.console.app import statistic as statistic_module
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
class _ConnContext:
def __init__(self, rows):
self._rows = rows
@ -48,42 +40,42 @@ def _install_common(monkeypatch: pytest.MonkeyPatch) -> None:
def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-01", message_count=3)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]}
def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 1
@ -94,14 +86,14 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey
def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTerminalsStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)]
_install_common(monkeypatch)
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]}
@ -111,13 +103,13 @@ def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypat
# This just verifies the decorator is applied correctly
# Actual endpoint testing would require complex JOIN mocking
api = statistic_module.AverageSessionInteractionStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
assert callable(method)
def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
def mock_parse(*args, **kwargs):
raise ValueError("Invalid time range")
@ -128,12 +120,12 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
with pytest.raises(BadRequest):
method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", message_count=10),
@ -144,7 +136,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 3
@ -152,20 +144,20 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa
def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyMessageStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
_install_common(monkeypatch)
_install_db(monkeypatch, [])
with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": []}
def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyConversationStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)]
_install_db(monkeypatch, rows)
@ -177,14 +169,14 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M
monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field)
with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]}
def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = statistic_module.DailyTokenCostStatistic()
method = _unwrap(api.get)
method = unwrap(api.get)
rows = [
SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"),
@ -194,7 +186,7 @@ def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.Monk
_install_db(monkeypatch, rows)
with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"):
response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1"))
data = response.get_json()
assert len(data["data"]) == 2

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import inspect
import json
from datetime import datetime
from types import SimpleNamespace
@ -7,6 +8,7 @@ from typing import cast
from unittest.mock import Mock
import pytest
from flask import Flask
from pydantic import ValidationError
from werkzeug.exceptions import HTTPException, NotFound
@ -17,12 +19,6 @@ from graphon.variables import SecretVariable, StringVariable
from graphon.variables.variables import RAGPipelineVariable
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_workflow(**overrides):
workflow = SimpleNamespace(
id="workflow-1",
@ -107,24 +103,20 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None:
build_mock.assert_called_once()
def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_sync_draft_workflow_invalid_content_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
handler = inspect.unwrap(api.post)
with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"):
with pytest.raises(HTTPException) as exc:
handler(api, app_model=SimpleNamespace(id="app"))
handler(api, "t1", app_model=SimpleNamespace(id="app"))
assert exc.value.code == 415
def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_sync_draft_workflow_invalid_json(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
@ -132,19 +124,19 @@ def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch)
data="[]",
content_type="application/json",
):
response, status = handler(api, app_model=SimpleNamespace(id="app"))
response, status = handler(api, "t1", app_model=SimpleNamespace(id="app"))
assert status == 400
assert response["message"] == "Invalid JSON data"
def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_sync_draft_workflow_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="h",
updated_at=None,
created_at=datetime(2024, 1, 1),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
monkeypatch.setattr(
workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env"
)
@ -156,20 +148,19 @@ def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> No
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
method="POST",
json={"graph": {}, "features": {}, "hash": "h"},
):
response = handler(api, app_model=SimpleNamespace(id="app"))
response = handler(api, "t1", app_model=SimpleNamespace(id="app"))
assert response["result"] == "success"
def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
def test_sync_draft_workflow_hash_mismatch(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
def _raise(*_args, **_kwargs):
raise workflow_module.WorkflowHashNotEqualError()
@ -178,7 +169,7 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch)
monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft",
@ -186,10 +177,10 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch)
json={"graph": {}, "features": {}, "hash": "h"},
):
with pytest.raises(DraftWorkflowNotSync):
handler(api, app_model=SimpleNamespace(id="app"))
handler(api, "t1", app_model=SimpleNamespace(id="app"))
def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_restore_published_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(
unique_hash="restored-hash",
updated_at=None,
@ -197,7 +188,6 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
)
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -205,7 +195,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/published-workflow/restore",
@ -213,6 +203,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
):
response = handler(
api,
"t1",
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="published-workflow",
)
@ -221,10 +212,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo
assert response["hash"] == "restored-hash"
def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
def test_restore_published_workflow_to_draft_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -236,7 +224,7 @@ def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/published-workflow/restore",
@ -245,15 +233,15 @@ def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.
with pytest.raises(NotFound):
handler(
api,
"t1",
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="published-workflow",
)
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None:
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
def test_restore_published_workflow_to_draft_returns_400_for_draft_source(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -268,7 +256,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/draft-workflow/restore",
@ -277,6 +265,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m
with pytest.raises(HTTPException) as exc:
handler(
api,
"t1",
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="draft-workflow",
)
@ -286,11 +275,8 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m
def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
app, monkeypatch: pytest.MonkeyPatch
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
user = SimpleNamespace(id="account-1")
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -302,7 +288,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
)
api = workflow_module.DraftWorkflowRestoreApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/workflows/published-workflow/restore",
@ -311,6 +297,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
with pytest.raises(HTTPException) as exc:
handler(
api,
"t1",
app_model=SimpleNamespace(id="app", tenant_id="tenant-1"),
workflow_id="published-workflow",
)
@ -319,9 +306,11 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure(
assert exc.value.description == "invalid workflow graph"
def test_get_published_workflows_serializes_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_get_published_workflows_serializes_items_before_session_closes(
app: Flask, monkeypatch: pytest.MonkeyPatch
) -> None:
api = workflow_module.PublishedAllWorkflowApi()
handler = _unwrap(api.get)
handler = inspect.unwrap(api.get)
session_state = {"open": False}
@ -351,7 +340,6 @@ def test_get_published_workflows_serializes_items_before_session_closes(app, mon
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker())
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
monkeypatch.setattr(
workflow_module,
"WorkflowService",
@ -365,7 +353,7 @@ def test_get_published_workflows_serializes_items_before_session_closes(app, mon
method="GET",
query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"},
):
response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
response = handler(api, "t1", app_model=SimpleNamespace(id="app", workflow_id="wf-1"))
assert response["items"][0]["id"] == "w1"
assert response["page"] == 1
@ -380,7 +368,7 @@ def test_draft_workflow_get_serializes_response_model(monkeypatch: pytest.Monkey
)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.get)
handler = inspect.unwrap(api.get)
response = handler(api, app_model=SimpleNamespace(id="app"))
@ -522,13 +510,13 @@ def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
)
api = workflow_module.DraftWorkflowApi()
handler = _unwrap(api.get)
handler = inspect.unwrap(api.get)
with pytest.raises(DraftWorkflowNotExist):
handler(api, app_model=SimpleNamespace(id="app"))
def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_advanced_chat_run_conversation_not_exists(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
workflow_module.AppGenerateService,
"generate",
@ -536,10 +524,9 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk
workflow_module.services.errors.conversation.ConversationNotExistsError()
),
)
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1"))
api = workflow_module.AdvancedChatDraftWorkflowRunApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/app/advanced-chat/workflows/draft/run",
@ -547,10 +534,10 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk
json={"inputs": {}},
):
with pytest.raises(NotFound):
handler(api, app_model=SimpleNamespace(id="app"))
handler(api, "t1", app_model=SimpleNamespace(id="app"))
def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
app_id_1 = "11111111-1111-1111-1111-111111111111"
app_id_2 = "22222222-2222-2222-2222-222222222222"
signed_avatar_url = "https://files.example.com/signed/avatar-1"
@ -602,7 +589,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p
monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory)
api = workflow_module.WorkflowOnlineUsersApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/workflows/online-users",
@ -636,7 +623,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p
sign_avatar.assert_called_once_with("avatar-file-id")
def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)]
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
monkeypatch.setattr(
@ -653,7 +640,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk
monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory)
api = workflow_module.WorkflowOnlineUsersApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/workflows/online-users",
@ -668,7 +655,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk
assert second_pipeline.hgetall.call_count == 1
def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None:
def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1"))
accessible_app_ids = Mock(return_value=set())
monkeypatch.setattr(
@ -680,7 +667,7 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch:
excessive_ids = [f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS + 1)]
api = workflow_module.WorkflowOnlineUsersApi()
handler = _unwrap(api.post)
handler = inspect.unwrap(api.post)
with app.test_request_context(
"/apps/workflows/online-users",

View File

@ -99,7 +99,7 @@ class MutationResponseCase:
expected_status: int | None = None
def _unwrap_response(result: object) -> tuple[dict[str, object], int | None]:
def unwrap_response(result: object) -> tuple[dict[str, object], int | None]:
if isinstance(result, tuple):
response, status = result
assert isinstance(response, dict)
@ -194,7 +194,7 @@ def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatc
with _patch_payload(payload):
result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123")
response, status = _unwrap_response(result)
response, status = unwrap_response(result)
assert response["id"] == "comment-1"
assert status == 201
create_comment_mock.assert_called_once_with(
@ -224,7 +224,7 @@ def test_update_comment_omits_mentions_when_payload_does_not_include_them(
with _patch_payload(payload):
result = workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1")
response, status = _unwrap_response(result)
response, status = unwrap_response(result)
assert response == {"id": "comment-1", "updated_at": JAN_1_2024_NOON_TS}
assert status is None
update_comment_mock.assert_called_once_with(
@ -480,7 +480,7 @@ def test_mutation_endpoints_serialize_response_models(
with _patch_payload(case.payload):
result = getattr(case.resource_cls(), case.method_name)(**case.kwargs)
response, status = _unwrap_response(result)
response, status = unwrap_response(result)
assert response == case.expected_response
assert status == case.expected_status

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from typing import Any
@ -12,12 +13,6 @@ from controllers.console.app import workflow_run as workflow_run_module
from models import Account
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _serialize_200_response(handler, payload: Any) -> Any:
response_doc = getattr(handler, "__apidoc__", {}).get("responses", {}).get("200")
if response_doc is None:
@ -100,7 +95,7 @@ def test_workflow_run_list_returns_frontend_history_contract(app: Flask, monkeyp
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
api = workflow_run_module.WorkflowRunListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/apps/app-1/workflow-runs?limit=10", method="GET"):
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"))
@ -141,7 +136,7 @@ def test_advanced_chat_workflow_run_list_keeps_message_fields(app: Flask, monkey
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
api = workflow_run_module.AdvancedChatAppWorkflowRunListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/apps/app-1/advanced-chat/workflow-runs?limit=1", method="GET"):
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"))
@ -180,7 +175,7 @@ def test_workflow_run_detail_returns_frontend_detail_contract(app: Flask, monkey
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
api = workflow_run_module.WorkflowRunDetailApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/apps/app-1/workflow-runs/run-1", method="GET"):
payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1")
@ -217,7 +212,7 @@ def test_workflow_run_node_executions_return_frontend_trace_contract(
monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService)
api = workflow_run_module.WorkflowRunNodeExecutionListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"):
payload = handler(

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
@ -12,12 +13,6 @@ from controllers.console.auth.data_source_bearer_auth import (
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _payload_patch(payload: dict):
return patch.object(
type(console_ns),
@ -29,7 +24,7 @@ def _payload_patch(payload: dict):
def test_list_data_source_auth_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSource()
method = _unwrap(api.get)
method = unwrap(api.get)
binding = SimpleNamespace(
id="binding-1",
category="api_key",
@ -52,7 +47,7 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None:
def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBinding()
method = _unwrap(api.post)
method = unwrap(api.post)
payload = {
"category": "api_key",
"provider": "custom",
@ -73,7 +68,7 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBindingDelete()
method = _unwrap(api.delete)
method = unwrap(api.delete)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"

View File

@ -41,13 +41,7 @@ def encode_code(code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode()
def _unwrap(func):
bound_self = getattr(func, "__self__", None)
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
if bound_self is not None:
return func.__get__(bound_self, bound_self.__class__)
return func
from inspect import unwrap
class TestLoginApi:
@ -510,7 +504,7 @@ class TestLogoutApi:
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = _unwrap(logout_api.post)(mock_account)
response = unwrap(logout_api.post)(logout_api, mock_account)
# Assert
mock_service_logout.assert_called_once_with(account=mock_account)
@ -536,7 +530,7 @@ class TestLogoutApi:
# Act
with app.test_request_context("/logout", method="POST"):
logout_api = LogoutApi()
response = _unwrap(logout_api.post)(anonymous_user)
response = unwrap(logout_api.post)(logout_api, anonymous_user)
# Assert
assert response.json["result"] == "success"

View File

@ -1,5 +1,6 @@
from __future__ import annotations
from inspect import unwrap
from unittest.mock import patch
from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi
@ -8,12 +9,6 @@ from models.account import AccountStatus, TenantAccountRole
from models.model import OAuthProviderApp
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_account() -> Account:
account = Account(
name="Test User",
@ -38,7 +33,7 @@ def _make_oauth_provider_app() -> OAuthProviderApp:
def test_oauth_authorize_uses_injected_current_user() -> None:
api = OAuthServerUserAuthorizeApi()
method = _unwrap(api.post)
method = unwrap(api.post)
account = _make_account()
oauth_provider_app = _make_oauth_provider_app()

View File

@ -49,7 +49,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app: Flask, mock_token_pair):
"""
Test successful token refresh flow.
@ -170,7 +170,7 @@ class TestRefreshTokenApi:
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app: Flask, mock_token_pair):
"""
Test that token refresh updates all three tokens.

View File

@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any, cast
from inspect import unwrap
from unittest.mock import PropertyMock, patch
import pytest
@ -20,12 +19,6 @@ from models.dataset import PipelineCustomizedTemplate
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
def _unwrap(func: object) -> Callable[..., Any]:
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return cast(Callable[..., Any], func)
def _template_item() -> dict[str, object]:
return {
"id": "template-1",
@ -60,7 +53,7 @@ def _payload() -> dict[str, object]:
class TestPipelineTemplateListApi:
def test_get_uses_query_defaults_and_serializes_nullable_fields(self, app: Flask) -> None:
api = PipelineTemplateListApi()
method = _unwrap(api.get)
method = unwrap(api.get)
service_calls: list[tuple[str, str]] = []
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
@ -87,7 +80,7 @@ class TestPipelineTemplateListApi:
def test_get_passes_explicit_query_to_service(self, app: Flask) -> None:
api = PipelineTemplateListApi()
method = _unwrap(api.get)
method = unwrap(api.get)
service_calls: list[tuple[str, str]] = []
def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]:
@ -108,7 +101,7 @@ class TestPipelineTemplateListApi:
class TestPipelineTemplateDetailApi:
def test_get_serializes_template_detail(self, app: Flask) -> None:
api = PipelineTemplateDetailApi()
method = _unwrap(api.get)
method = unwrap(api.get)
service_calls: list[tuple[str, str]] = []
class Service:
@ -128,7 +121,7 @@ class TestPipelineTemplateDetailApi:
def test_get_raises_not_found_without_custom_response_body(self, app: Flask) -> None:
api = PipelineTemplateDetailApi()
method = _unwrap(api.get)
method = unwrap(api.get)
class Service:
def get_pipeline_template_detail(self, template_id: str, template_type: str) -> None:
@ -145,7 +138,7 @@ class TestPipelineTemplateDetailApi:
class TestCustomizedPipelineTemplateApi:
def test_patch_validates_payload_and_returns_empty_204(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = _unwrap(api.patch)
method = unwrap(api.patch)
payload = _payload()
service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = []
@ -174,7 +167,7 @@ class TestCustomizedPipelineTemplateApi:
def test_patch_defaults_missing_icon_info_before_service_call(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = _unwrap(api.patch)
method = unwrap(api.patch)
payload: dict[str, object] = {
"name": "Updated template",
"description": "Updated description",
@ -204,7 +197,7 @@ class TestCustomizedPipelineTemplateApi:
def test_delete_returns_empty_204(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = _unwrap(api.delete)
method = unwrap(api.delete)
deleted_template_ids: list[str] = []
def delete_template(template_id: str) -> None:
@ -221,7 +214,7 @@ class TestCustomizedPipelineTemplateApi:
def test_post_exports_yaml_from_orm_template(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
template = PipelineCustomizedTemplate(
tenant_id="00000000-0000-0000-0000-000000000001",
name="Template",
@ -265,7 +258,7 @@ class TestCustomizedPipelineTemplateApi:
def test_post_raises_when_template_is_missing(self, app: Flask) -> None:
api = CustomizedPipelineTemplateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
class Session:
def scalar(self, stmt: object) -> None:
@ -297,7 +290,7 @@ class TestCustomizedPipelineTemplateApi:
class TestPublishCustomizedPipelineTemplateApi:
def test_post_validates_payload_and_returns_empty_204(self, app: Flask) -> None:
api = PublishCustomizedPipelineTemplateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
payload = _payload()
service_calls: list[tuple[str, dict[str, object]]] = []
@ -317,7 +310,7 @@ class TestPublishCustomizedPipelineTemplateApi:
def test_post_allows_missing_icon_info_for_publish_service_fallback(self, app: Flask) -> None:
api = PublishCustomizedPipelineTemplateApi()
method = _unwrap(api.post)
method = unwrap(api.post)
payload: dict[str, object] = {
"name": "Published template",
"description": "Description",

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from datetime import datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
@ -9,12 +10,6 @@ import pytest
from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_workflow(**overrides):
workflow = SimpleNamespace(
id="workflow-1",
@ -45,7 +40,7 @@ def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch:
)
api = module.DraftRagPipelineApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
response = handler(api, pipeline=SimpleNamespace(id="pipeline-1"))
@ -63,7 +58,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
api = module.PublishedAllRagPipelineApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
session_state = {"open": False}
class _SessionContext:
@ -133,7 +128,7 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch:
payload: dict[str, object] = {"marked_name": "Updated release"}
api = module.RagPipelineByIdApi()
handler = _unwrap(api.patch)
handler = unwrap(api.patch)
with (
app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload),

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import json
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -23,12 +24,6 @@ from models.human_input import RecipientType
from models.model import AppMode
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def test_jsonify_form_definition() -> None:
expiration = datetime(2024, 1, 1, tzinfo=UTC)
definition = SimpleNamespace(model_dump=lambda: {"fields": []})
@ -64,7 +59,7 @@ def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
response = handler(api, "tenant-1", form_token="token")
@ -85,7 +80,7 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/console/api/form/human_input/token", method="GET"):
with pytest.raises(NotFoundError):
@ -106,7 +101,7 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
@ -137,7 +132,7 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
@ -172,7 +167,7 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleHumanInputFormApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context(
"/console/api/form/human_input/token",
@ -244,7 +239,7 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch)
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
@ -271,7 +266,7 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
@ -298,7 +293,7 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
with pytest.raises(NotFoundError):
@ -342,7 +337,7 @@ def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) -
monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object()))
api = ConsoleWorkflowEventsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
with app.test_request_context("/console/api/workflow/run/events", method="GET"):
response = handler(api, "t1", SimpleNamespace(id="user-1"), workflow_run_id="run-1")

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import urllib.parse
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import MagicMock
@ -16,12 +17,6 @@ from services.errors.file import FileTooLargeError as ServiceFileTooLargeError
from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _make_account(account_id: str = "u1") -> Account:
account = Account(
name="Test User",
@ -89,7 +84,7 @@ def _mock_upload_dependencies(
def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
handler = unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
@ -110,7 +105,7 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest
def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
handler = unwrap(api.get)
target_url = "http://example.com/api/aiagent/httpview/txt"
query = "fileNameKey=cankao1_ce4305bc-be20-4c5d-8732-de1741d28e27"
@ -131,7 +126,7 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch:
def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.GetRemoteFileInfo()
handler = _unwrap(api.get)
handler = unwrap(api.get)
decoded_url = "https://example.com/test.txt"
encoded_url = urllib.parse.quote(decoded_url, safe="")
@ -154,7 +149,7 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo
def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/report.txt"
get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content")
@ -196,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
app, monkeypatch: pytest.MonkeyPatch
) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/photo.jpg"
head_resp = _FakeResponse(status_code=200, method="HEAD", content=b"head-content")
@ -227,7 +222,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds(
def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/fail.txt"
make_request = MagicMock(
@ -245,7 +240,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat
def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/fail.txt"
request = httpx.Request("HEAD", url)
@ -259,7 +254,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte
def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/large.bin"
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
@ -274,7 +269,7 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk
def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/large.bin"
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))
@ -289,7 +284,7 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp
def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None:
api = remote_files_module.RemoteFileUpload()
handler = _unwrap(api.post)
handler = unwrap(api.post)
url = "https://example.com/file.exe"
make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload"))

View File

@ -10,6 +10,7 @@ from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from flask import Flask
from werkzeug.exceptions import NotFound
from controllers.openapi.auth.data import AuthData
@ -30,7 +31,7 @@ def _make_auth_data(app_model, caller, caller_kind):
class TestOpenApiHumanInputFormGet:
def test_get_success(self, app, bypass_pipeline, monkeypatch):
def test_get_success(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
definition = SimpleNamespace(
@ -74,7 +75,7 @@ class TestOpenApiHumanInputFormGet:
assert payload["user_actions"] == [{"id": "submit", "title": "Submit"}]
service_mock.ensure_form_active.assert_called_once_with(form)
def test_get_form_not_found(self, app, bypass_pipeline, monkeypatch):
def test_get_form_not_found(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
service_mock = Mock()
@ -96,7 +97,7 @@ class TestOpenApiHumanInputFormGet:
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
def test_get_form_wrong_app(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
form = SimpleNamespace(
@ -121,7 +122,7 @@ class TestOpenApiHumanInputFormGet:
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
def test_get_form_wrong_surface(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
form = SimpleNamespace(
@ -159,7 +160,7 @@ class TestOpenApiHumanInputFormPost:
expiration_time=datetime(2099, 1, 1, tzinfo=UTC),
)
def test_post_account_caller_uses_user_id(self, app, bypass_pipeline, monkeypatch):
def test_post_account_caller_uses_user_id(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
form = self._make_form()
@ -196,7 +197,7 @@ class TestOpenApiHumanInputFormPost:
)
assert result == ({}, 200)
def test_post_end_user_caller_uses_end_user_id(self, app, bypass_pipeline, monkeypatch):
def test_post_end_user_caller_uses_end_user_id(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
form = self._make_form()

View File

@ -13,6 +13,7 @@ Note: API endpoint tests for annotation controllers are complex due to:
"""
import uuid
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -34,13 +35,6 @@ from extensions.ext_redis import redis_client
from models.model import App
from services.annotation_service import AppAnnotationService
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
# ---------------------------------------------------------------------------
# Pydantic Model Tests
# ---------------------------------------------------------------------------
@ -193,7 +187,7 @@ class TestAnnotationReplyActionApi:
monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock)
api = AnnotationReplyActionApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context(
@ -211,7 +205,7 @@ class TestAnnotationReplyActionApi:
monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock)
api = AnnotationReplyActionApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context(
@ -230,7 +224,7 @@ class TestAnnotationReplyActionStatusApi:
monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None)
api = AnnotationReplyActionStatusApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app")
with pytest.raises(ValueError):
@ -245,7 +239,7 @@ class TestAnnotationReplyActionStatusApi:
monkeypatch.setattr(redis_client, "get", _get)
api = AnnotationReplyActionStatusApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app")
response, status = handler(api, app_model=app_model, job_id="j1", action="enable")
@ -262,7 +256,7 @@ class TestAnnotationListApi:
monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock)
api = AnnotationListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations", method="GET"):
@ -278,7 +272,7 @@ class TestAnnotationListApi:
monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock)
api = AnnotationListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations?page=2&limit=5&keyword=refund", method="GET"):
@ -297,7 +291,7 @@ class TestAnnotationListApi:
monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock)
api = AnnotationListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app")
with app.test_request_context(f"/apps/annotations?{query_string}", method="GET"):
@ -315,7 +309,7 @@ class TestAnnotationListApi:
)
api = AnnotationListApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}):
@ -337,8 +331,8 @@ class TestAnnotationUpdateDeleteApi:
monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock)
api = AnnotationUpdateDeleteApi()
put_handler = _unwrap(api.put)
delete_handler = _unwrap(api.delete)
put_handler = unwrap(api.put)
delete_handler = unwrap(api.delete)
app_model = SimpleNamespace(id="app")
with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}):

View File

@ -9,6 +9,7 @@ Tests coverage for:
import io
import uuid
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -41,12 +42,6 @@ from services.errors.audio import (
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
def _file_data():
return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav")
@ -194,7 +189,7 @@ class TestAudioApi:
def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"})
api = AudioApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
@ -220,7 +215,7 @@ class TestAudioApi:
def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None:
monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc))
api = AudioApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
@ -233,7 +228,7 @@ class TestAudioApi:
AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))
)
api = AudioApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(id="u1")
@ -247,7 +242,7 @@ class TestTextApi:
monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"})
api = TextApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")
@ -266,7 +261,7 @@ class TestTextApi:
)
api = TextApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="a1")
end_user = SimpleNamespace(external_user_id="ext")

View File

@ -12,6 +12,7 @@ Focus on:
"""
import uuid
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -44,12 +45,6 @@ from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestCompletionRequestPayload:
"""Test suite for CompletionRequestPayload Pydantic model."""
@ -426,7 +421,7 @@ class TestChatRequestPayloadController:
class TestCompletionApiController:
def test_wrong_mode(self, app: Flask) -> None:
api = CompletionApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -444,7 +439,7 @@ class TestCompletionApiController:
end_user = SimpleNamespace()
api = CompletionApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}):
with pytest.raises(NotFound):
@ -454,7 +449,7 @@ class TestCompletionApiController:
class TestCompletionStopApiController:
def test_wrong_mode(self, app: Flask) -> None:
api = CompletionStopApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="u1")
@ -467,7 +462,7 @@ class TestCompletionStopApiController:
monkeypatch.setattr(AppTaskService, "stop_task", stop_mock)
api = CompletionStopApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace(id="u1")
@ -481,7 +476,7 @@ class TestCompletionStopApiController:
class TestChatApiController:
def test_wrong_mode(self, app: Flask) -> None:
api = ChatApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
@ -497,7 +492,7 @@ class TestChatApiController:
)
api = ChatApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -513,7 +508,7 @@ class TestChatApiController:
)
api = ChatApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -525,7 +520,7 @@ class TestChatApiController:
class TestChatStopApiController:
def test_wrong_mode(self, app: Flask) -> None:
api = ChatStopApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace(id="u1")

View File

@ -16,6 +16,7 @@ Focus on:
import sys
import uuid
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -51,12 +52,6 @@ from services.errors.conversation import (
)
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestConversationListQuery:
"""Test suite for ConversationListQuery Pydantic model."""
@ -380,7 +375,7 @@ class TestConversationAppModeValidation:
app raises NotChatAppError.
"""
app = Mock(spec=App)
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
app_mode = AppMode.value_of(app.mode)
assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}
@ -498,7 +493,7 @@ class TestConversationPayloadsController:
class TestConversationApiController:
def test_list_not_chat(self, app: Flask) -> None:
api = ConversationApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
@ -531,7 +526,7 @@ class TestConversationApiController:
monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub)
api = ConversationApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -546,7 +541,7 @@ class TestConversationApiController:
class TestConversationDetailApiController:
def test_delete_not_chat(self, app: Flask) -> None:
api = ConversationDetailApi()
handler = _unwrap(api.delete)
handler = unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
@ -562,7 +557,7 @@ class TestConversationDetailApiController:
)
api = ConversationDetailApi()
handler = _unwrap(api.delete)
handler = unwrap(api.delete)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -580,7 +575,7 @@ class TestConversationRenameApiController:
)
api = ConversationRenameApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -596,7 +591,7 @@ class TestConversationRenameApiController:
class TestConversationVariablesApiController:
def test_not_chat(self, app: Flask) -> None:
api = ConversationVariablesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION)
end_user = SimpleNamespace()
@ -612,7 +607,7 @@ class TestConversationVariablesApiController:
)
api = ConversationVariablesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -645,7 +640,7 @@ class TestConversationVariablesApiController:
)
api = ConversationVariablesApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -671,7 +666,7 @@ class TestConversationVariableDetailApiController:
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
handler = unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -697,7 +692,7 @@ class TestConversationVariableDetailApiController:
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
handler = unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()
@ -731,7 +726,7 @@ class TestConversationVariableDetailApiController:
)
api = ConversationVariableDetailApi()
handler = _unwrap(api.put)
handler = unwrap(api.put)
app_model = SimpleNamespace(mode=AppMode.CHAT)
end_user = SimpleNamespace()

View File

@ -203,7 +203,7 @@ class TestFileUploadResponse:
# unwrapped method directly to bypass the decorator.
# =============================================================================
from tests.unit_tests.controllers.service_api.conftest import _unwrap
from inspect import unwrap
@pytest.fixture
@ -274,7 +274,7 @@ class TestFileApiPost:
data=data,
):
api = FileApi()
response, status = _unwrap(api.post)(
response, status = unwrap(api.post)(
api,
app_model=mock_app_model,
end_user=mock_end_user,
@ -295,7 +295,7 @@ class TestFileApiPost:
):
api = FileApi()
with pytest.raises(NoFileUploadedError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user):
"""Test TooManyFilesError when multiple files uploaded."""
@ -316,7 +316,7 @@ class TestFileApiPost:
):
api = FileApi()
with pytest.raises(TooManyFilesError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user):
"""Test UnsupportedFileTypeError when file has no mimetype."""
@ -334,7 +334,7 @@ class TestFileApiPost:
):
api = FileApi()
with pytest.raises(UnsupportedFileTypeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
@ -366,7 +366,7 @@ class TestFileApiPost:
):
api = FileApi()
with pytest.raises(FileTooLargeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
@patch("controllers.service_api.app.file.FileService")
@patch("controllers.service_api.app.file.db")
@ -396,4 +396,4 @@ class TestFileApiPost:
):
api = FileApi()
with pytest.raises(UnsupportedFileTypeError):
_unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)
unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user)

View File

@ -7,6 +7,7 @@ import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from typing import override
from unittest.mock import ANY, MagicMock, Mock
@ -44,7 +45,6 @@ from repositories.api_workflow_node_execution_repository import WorkflowNodeExec
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.app_generate_service import AppGenerateService
from services.workflow_event_snapshot_service import _build_snapshot_events
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class _DummyRateLimit:
@ -275,8 +275,8 @@ class TestHitlServiceApi:
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"):
@ -310,8 +310,8 @@ class TestHitlServiceApi:
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context(

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -15,7 +16,6 @@ from werkzeug.exceptions import NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload
from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi
from models.human_input import RecipientType
from tests.unit_tests.controllers.service_api.conftest import _unwrap
class TestWorkflowHumanInputFormApi:
@ -45,7 +45,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
@ -98,7 +98,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
@ -121,7 +121,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
@ -153,7 +153,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
with app.test_request_context("/form/human_input/token-1", method="GET"):
@ -175,7 +175,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
@ -209,7 +209,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")
inputs = {
@ -285,7 +285,7 @@ class TestWorkflowHumanInputFormApi:
monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object()))
api = WorkflowHumanInputFormApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
end_user = SimpleNamespace(id="end-user-1")

View File

@ -15,6 +15,7 @@ Focus on:
"""
import uuid
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -43,12 +44,6 @@ from services.errors.message import (
from services.message_service import MessageService
def _unwrap(func):
while hasattr(func, "__wrapped__"):
func = func.__wrapped__
return func
class TestMessageListQuery:
"""Test suite for MessageListQuery Pydantic model."""
@ -383,7 +378,7 @@ class TestMessageService:
class TestMessageListApi:
def test_not_chat_app(self, app: Flask) -> None:
api = MessageListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
@ -399,7 +394,7 @@ class TestMessageListApi:
)
api = MessageListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -418,7 +413,7 @@ class TestMessageListApi:
)
api = MessageListApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -439,7 +434,7 @@ class TestMessageFeedbackApi:
)
api = MessageFeedbackApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace()
end_user = SimpleNamespace()
@ -457,7 +452,7 @@ class TestAppGetFeedbacksApi:
monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"])
api = AppGetFeedbacksApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace()
with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"):
@ -469,7 +464,7 @@ class TestAppGetFeedbacksApi:
class TestMessageSuggestedApi:
def test_not_chat(self, app: Flask) -> None:
api = MessageSuggestedApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.COMPLETION.value)
end_user = SimpleNamespace()
@ -485,7 +480,7 @@ class TestMessageSuggestedApi:
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -501,7 +496,7 @@ class TestMessageSuggestedApi:
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -517,7 +512,7 @@ class TestMessageSuggestedApi:
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -533,7 +528,7 @@ class TestMessageSuggestedApi:
)
api = MessageSuggestedApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()

View File

@ -16,6 +16,7 @@ Focus on:
import sys
import uuid
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock, patch
@ -369,7 +370,7 @@ class TestWorkflowRunRepository:
class TestWorkflowRunDetailApi:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
with app.test_request_context("/workflows/run/1", method="GET"):
@ -388,8 +389,8 @@ class TestWorkflowRunDetailApi:
)
api = WorkflowRunDetailApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1")
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW, tenant_id="t1", id="a1")
result = handler(api, app_model=app_model, workflow_run_id="run")
assert result["id"] == "run"
@ -400,7 +401,7 @@ class TestWorkflowRunDetailApi:
class TestWorkflowRunApi:
def test_not_workflow_app(self, app: Flask) -> None:
api = WorkflowRunApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -416,8 +417,8 @@ class TestWorkflowRunApi:
)
api = WorkflowRunApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}):
@ -434,8 +435,8 @@ class TestWorkflowRunByIdApi:
)
api = WorkflowRunByIdApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
@ -450,8 +451,8 @@ class TestWorkflowRunByIdApi:
)
api = WorkflowRunByIdApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
end_user = SimpleNamespace()
with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}):
@ -462,7 +463,7 @@ class TestWorkflowRunByIdApi:
class TestWorkflowTaskStopApi:
def test_wrong_mode(self, app: Flask) -> None:
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace()
@ -477,8 +478,8 @@ class TestWorkflowTaskStopApi:
monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock)
api = WorkflowTaskStopApi()
handler = _unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value)
handler = unwrap(api.post)
app_model = SimpleNamespace(mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="u1")
with app.test_request_context("/workflows/tasks/1/stop", method="POST"):
@ -515,7 +516,7 @@ class TestWorkflowAppLogApi:
)
api = WorkflowAppLogApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="a1")
with app.test_request_context("/workflows/logs", method="GET"):
@ -533,15 +534,13 @@ class TestWorkflowAppLogApi:
# directly to bypass the decorator.
# =============================================================================
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
def mock_workflow_app():
app = Mock(spec=App)
app.id = str(uuid.uuid4())
app.tenant_id = str(uuid.uuid4())
app.mode = AppMode.WORKFLOW.value
app.mode = AppMode.WORKFLOW
return app
@ -574,7 +573,7 @@ class TestWorkflowRunDetailApiGet:
method="GET",
):
api = WorkflowRunDetailApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
result = unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id)
assert result["id"] == mock_run.id
assert result["status"] == "succeeded"
@ -590,7 +589,7 @@ class TestWorkflowRunDetailApiGet:
with app.test_request_context("/workflows/run/run-1", method="GET"):
api = WorkflowRunDetailApi()
with pytest.raises(NotWorkflowAppError):
_unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1")
class TestWorkflowTaskStopApiPost:
@ -613,7 +612,7 @@ class TestWorkflowTaskStopApiPost:
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
api = WorkflowTaskStopApi()
result = _unwrap(api.post)(
result = unwrap(api.post)(
api,
app_model=mock_workflow_app,
end_user=Mock(),
@ -635,7 +634,7 @@ class TestWorkflowTaskStopApiPost:
with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"):
api = WorkflowTaskStopApi()
with pytest.raises(NotWorkflowAppError):
_unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1")
class TestWorkflowAppLogApiGet:
@ -681,6 +680,6 @@ class TestWorkflowAppLogApiGet:
):
with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory):
api = WorkflowAppLogApi()
result = _unwrap(api.get)(api, app_model=mock_workflow_app)
result = unwrap(api.get)(api, app_model=mock_workflow_app)
assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []}

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import json
import sys
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import Mock
@ -16,7 +17,6 @@ from controllers.service_api.app.error import NotWorkflowAppError
from controllers.service_api.app.workflow_events import WorkflowEventsApi
from models.enums import CreatorUserRole
from models.model import AppMode
from tests.unit_tests.controllers.service_api.conftest import _unwrap
def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
@ -34,7 +34,7 @@ def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run):
class TestWorkflowEventsApi:
def test_wrong_app_mode(self, app: Flask) -> None:
api = WorkflowEventsApi()
handler = _unwrap(api.get)
handler = unwrap(api.get)
app_model = SimpleNamespace(mode=AppMode.CHAT.value)
end_user = SimpleNamespace(id="end-user-1")
@ -45,8 +45,8 @@ class TestWorkflowEventsApi:
def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None:
_mock_repo_for_run(monkeypatch, workflow_run=None)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
@ -63,8 +63,8 @@ class TestWorkflowEventsApi:
)
_mock_repo_for_run(monkeypatch, workflow_run=workflow_run)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
@ -90,8 +90,8 @@ class TestWorkflowEventsApi:
)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
@ -121,8 +121,8 @@ class TestWorkflowEventsApi:
monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"):
@ -154,8 +154,8 @@ class TestWorkflowEventsApi:
monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder)
api = WorkflowEventsApi()
handler = _unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
handler = unwrap(api.get)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
end_user = SimpleNamespace(id="end-user-1")
with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"):

View File

@ -204,11 +204,3 @@ def mock_child_chunk():
child_chunk.tenant_id = str(uuid.uuid4())
child_chunk.content = "Test child chunk content"
return child_chunk
def _unwrap(method):
"""Walk ``__wrapped__`` chain to get the original function."""
fn = method
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return fn

View File

@ -16,6 +16,7 @@ Decorator strategy:
"""
import uuid
from inspect import unwrap
from unittest.mock import Mock, patch
import pytest
@ -29,7 +30,6 @@ from controllers.service_api.dataset.metadata import (
DatasetMetadataServiceApi,
DocumentMetadataEditServiceApi,
)
from tests.unit_tests.controllers.service_api.conftest import _unwrap
@pytest.fixture
@ -65,7 +65,7 @@ class TestDatasetMetadataCreatePost:
@staticmethod
def _call_post(api, **kwargs):
return _unwrap(api.post)(api, **kwargs)
return unwrap(api.post)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@ -195,7 +195,7 @@ class TestDatasetMetadataServiceApiPatch:
@staticmethod
def _call_patch(api, **kwargs):
return _unwrap(api.patch)(api, **kwargs)
return unwrap(api.patch)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@ -267,7 +267,7 @@ class TestDatasetMetadataServiceApiDelete:
@staticmethod
def _call_delete(api, **kwargs):
return _unwrap(api.delete)(api, **kwargs)
return unwrap(api.delete)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@ -376,7 +376,7 @@ class TestDatasetMetadataBuiltInFieldAction:
@staticmethod
def _call_post(api, **kwargs):
return _unwrap(api.post)(api, **kwargs)
return unwrap(api.post)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")
@ -479,7 +479,7 @@ class TestDocumentMetadataEditPost:
@staticmethod
def _call_post(api, **kwargs):
return _unwrap(api.post)(api, **kwargs)
return unwrap(api.post)(api, **kwargs)
@patch("controllers.service_api.dataset.metadata.MetadataService")
@patch("controllers.service_api.dataset.metadata.DatasetService")

View File

@ -7,7 +7,7 @@ from models.model import AppMode
class TestWorkflowAppConfigManager:
def test_get_app_config(self):
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value)
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
workflow = SimpleNamespace(id="wf-1", features_dict={})
with (

View File

@ -383,7 +383,7 @@ class TestWorkflowService:
assert result == mock_workflow
def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service):
def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service: WorkflowService):
"""Test get_draft_workflow passes an injected session to published workflow lookup."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
workflow_id = "workflow-123"
@ -458,7 +458,7 @@ class TestWorkflowService:
assert result == mock_workflow
def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service):
def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service: WorkflowService):
"""Test get_published_workflow returns None when app has no workflow_id."""
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
@ -658,21 +658,21 @@ class TestWorkflowService:
# ==================== Workflow Validation Tests ====================
# These tests verify graph structure and feature configuration validation
def test_validate_graph_structure_empty_graph(self, workflow_service):
def test_validate_graph_structure_empty_graph(self, workflow_service: WorkflowService):
"""Test validate_graph_structure accepts empty graph."""
graph = {"nodes": []}
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_valid_graph(self, workflow_service):
def test_validate_graph_structure_valid_graph(self, workflow_service: WorkflowService):
"""Test validate_graph_structure accepts valid graph."""
graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()
# Should not raise any exception
workflow_service.validate_graph_structure(graph)
def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service):
def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service: WorkflowService):
"""
Test validate_graph_structure raises error when start and trigger nodes coexist.
@ -707,7 +707,7 @@ class TestWorkflowService:
with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"):
workflow_service.validate_graph_structure(graph)
def test_validate_features_structure_workflow_mode(self, workflow_service):
def test_validate_features_structure_workflow_mode(self, workflow_service: WorkflowService):
"""
Test validate_features_structure for workflow mode.
@ -723,7 +723,7 @@ class TestWorkflowService:
tenant_id=app.tenant_id, config=features, only_structure_validate=True
)
def test_validate_features_structure_advanced_chat_mode(self, workflow_service):
def test_validate_features_structure_advanced_chat_mode(self, workflow_service: WorkflowService):
"""Test validate_features_structure for advanced chat mode."""
app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.ADVANCED_CHAT)
features = {"opening_statement": "Hello"}
@ -734,7 +734,7 @@ class TestWorkflowService:
tenant_id=app.tenant_id, config=features, only_structure_validate=True
)
def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service):
def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service: WorkflowService):
"""Test validate_features_structure raises error for invalid mode."""
app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION)
features = {}
@ -767,7 +767,7 @@ class TestWorkflowService:
assert workflow.updated_at == "now"
mock_db_session.session.commit.assert_called_once()
def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service):
def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service: WorkflowService):
"""Test update_draft_workflow_environment_variables raises when draft missing."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
account = TestWorkflowAssociatedDataFactory.create_account_mock()
@ -802,7 +802,7 @@ class TestWorkflowService:
assert workflow.updated_at == "now"
mock_db_session.session.commit.assert_called_once()
def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service):
def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service: WorkflowService):
"""Test update_draft_workflow_conversation_variables raises when draft missing."""
app = TestWorkflowAssociatedDataFactory.create_app_mock()
account = TestWorkflowAssociatedDataFactory.create_account_mock()
@ -917,7 +917,7 @@ class TestWorkflowService:
# ==================== Version Management Tests ====================
# These tests verify listing and managing published workflow versions
def test_get_all_published_workflow_with_pagination(self, workflow_service):
def test_get_all_published_workflow_with_pagination(self, workflow_service: WorkflowService):
"""
Test get_all_published_workflow returns paginated results.
@ -950,7 +950,7 @@ class TestWorkflowService:
assert len(workflows) == 5
assert has_more is False
def test_get_all_published_workflow_has_more(self, workflow_service):
def test_get_all_published_workflow_has_more(self, workflow_service: WorkflowService):
"""
Test get_all_published_workflow indicates has_more when results exceed limit.
@ -983,7 +983,7 @@ class TestWorkflowService:
assert len(workflows) == 10
assert has_more is True
def test_get_all_published_workflow_no_workflow_id(self, workflow_service):
def test_get_all_published_workflow_no_workflow_id(self, workflow_service: WorkflowService):
"""Test get_all_published_workflow returns empty when app has no workflow_id."""
app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None)
mock_session = MagicMock()
@ -998,7 +998,7 @@ class TestWorkflowService:
# ==================== Update Workflow Tests ====================
# These tests verify updating workflow metadata (name, comments, etc.)
def test_update_workflow_success(self, workflow_service):
def test_update_workflow_success(self, workflow_service: WorkflowService):
"""
Test update_workflow updates workflow attributes.
@ -1032,7 +1032,7 @@ class TestWorkflowService:
assert mock_workflow.marked_comment == "Updated Comment"
assert mock_workflow.updated_by == account_id
def test_update_workflow_not_found(self, workflow_service):
def test_update_workflow_not_found(self, workflow_service: WorkflowService):
"""Test update_workflow returns None when workflow not found."""
mock_session = MagicMock()
mock_session.scalar.return_value = None
@ -1055,7 +1055,7 @@ class TestWorkflowService:
# ==================== Delete Workflow Tests ====================
# These tests verify workflow deletion with safety checks
def test_delete_workflow_success(self, workflow_service):
def test_delete_workflow_success(self, workflow_service: WorkflowService):
"""
Test delete_workflow successfully deletes a published workflow.
@ -1085,7 +1085,7 @@ class TestWorkflowService:
assert result is True
mock_session.delete.assert_called_once_with(mock_workflow)
def test_delete_workflow_draft_raises_error(self, workflow_service):
def test_delete_workflow_draft_raises_error(self, workflow_service: WorkflowService):
"""
Test delete_workflow raises error when trying to delete draft.
@ -1109,7 +1109,7 @@ class TestWorkflowService:
with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow"):
workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service):
def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service: WorkflowService):
"""
Test delete_workflow raises error when workflow is in use by app.
@ -1132,7 +1132,7 @@ class TestWorkflowService:
with pytest.raises(WorkflowInUseError, match="currently in use by app"):
workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
def test_delete_workflow_published_as_tool_raises_error(self, workflow_service):
def test_delete_workflow_published_as_tool_raises_error(self, workflow_service: WorkflowService):
"""
Test delete_workflow raises error when workflow is published as tool.
@ -1156,7 +1156,7 @@ class TestWorkflowService:
with pytest.raises(WorkflowInUseError, match="published as a tool"):
workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id)
def test_delete_workflow_not_found_raises_error(self, workflow_service):
def test_delete_workflow_not_found_raises_error(self, workflow_service: WorkflowService):
"""Test delete_workflow raises error when workflow not found."""
workflow_id = "nonexistent"
tenant_id = "tenant-456"
@ -1175,7 +1175,7 @@ class TestWorkflowService:
# ==================== Get Default Block Config Tests ====================
# These tests verify retrieval of default node configurations
def test_get_default_block_configs(self, workflow_service):
def test_get_default_block_configs(self, workflow_service: WorkflowService):
"""
Test get_default_block_configs returns list of default configs.
@ -1195,7 +1195,7 @@ class TestWorkflowService:
assert len(result) > 0
def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service):
def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service: WorkflowService):
injected_config = HttpRequestNodeConfig(
max_connect_timeout=15,
max_read_timeout=25,
@ -1234,7 +1234,7 @@ class TestWorkflowService:
assert passed_http_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
mock_llm_node_class.get_default_config.assert_called_once_with(filters=None)
def test_get_default_block_config_for_node_type(self, workflow_service):
def test_get_default_block_config_for_node_type(self, workflow_service: WorkflowService):
"""
Test get_default_block_config returns config for specific node type.
@ -1258,7 +1258,7 @@ class TestWorkflowService:
assert result == mock_config
mock_node_class.get_default_config.assert_called_once()
def test_get_default_block_config_invalid_node_type(self, workflow_service):
def test_get_default_block_config_invalid_node_type(self, workflow_service: WorkflowService):
"""Test get_default_block_config returns empty dict for invalid node type."""
with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping:
mock_mapping.return_value = {}
@ -1268,7 +1268,7 @@ class TestWorkflowService:
assert result == {}
def test_get_default_block_config_http_request_injects_default_config(self, workflow_service):
def test_get_default_block_config_http_request_injects_default_config(self, workflow_service: WorkflowService):
injected_config = HttpRequestNodeConfig(
max_connect_timeout=11,
max_read_timeout=22,
@ -1299,7 +1299,7 @@ class TestWorkflowService:
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config
def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service):
def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service: WorkflowService):
provided_config = HttpRequestNodeConfig(
max_connect_timeout=13,
max_read_timeout=23,
@ -1330,7 +1330,9 @@ class TestWorkflowService:
passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"]
assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is provided_config
def test_get_default_block_config_http_request_malformed_config_raises_type_error(self, workflow_service):
def test_get_default_block_config_http_request_malformed_config_raises_type_error(
self, workflow_service: WorkflowService
):
with (
patch(
"services.workflow_service.get_node_type_classes_mapping",
@ -1347,7 +1349,7 @@ class TestWorkflowService:
# ==================== Workflow Conversion Tests ====================
# These tests verify converting basic apps to workflow apps
def test_convert_to_workflow_from_chat_app(self, workflow_service):
def test_convert_to_workflow_from_chat_app(self, workflow_service: WorkflowService):
"""
Test convert_to_workflow converts chat app to workflow.
@ -1374,7 +1376,7 @@ class TestWorkflowService:
assert result == mock_new_app
mock_converter.convert_to_workflow.assert_called_once()
def test_convert_to_workflow_from_completion_app(self, workflow_service):
def test_convert_to_workflow_from_completion_app(self, workflow_service: WorkflowService):
"""
Test convert_to_workflow converts completion app to workflow.
@ -1395,7 +1397,7 @@ class TestWorkflowService:
assert result == mock_new_app
def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service):
def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service: WorkflowService):
"""
Test convert_to_workflow raises error for invalid app mode.