From d11e4eeaf793723b6de09556d1e617cbd6c2e428 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 9 Jun 2026 14:06:28 +0900 Subject: [PATCH] chore: DI current_user && use inspect (#37084) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/workflow.py | 99 ++++++++------- .../controllers/console/app/test_app_apis.py | 78 ++++++------ .../console/app/test_app_import_api.py | 38 +++--- .../service_api/dataset/test_dataset.py | 20 ++- .../controllers/service_api/test_site.py | 12 +- .../controllers/web/test_human_input_form.py | 2 +- .../test_human_input_resume_node_execution.py | 2 +- .../services/test_app_dsl_service.py | 22 ++-- .../test_human_input_delivery_test.py | 2 +- .../console/agent/test_agent_controllers.py | 45 +++---- .../console/app/test_app_import_api.py | 26 ++-- .../controllers/console/app/test_audio.py | 95 +++++++------- .../controllers/console/app/test_audio_api.py | 53 ++++---- .../console/app/test_conversation_api.py | 33 ++--- .../app/test_conversation_variables_api.py | 22 ++-- .../console/app/test_generator_api.py | 119 +++++++++--------- .../console/app/test_message_api.py | 9 -- .../console/app/test_model_config_api.py | 18 +-- .../console/app/test_statistic_api.py | 48 +++---- .../controllers/console/app/test_workflow.py | 103 +++++++-------- .../console/app/test_workflow_comment_api.py | 8 +- .../console/app/test_workflow_run_api.py | 15 +-- .../auth/test_data_source_bearer_auth.py | 13 +- .../console/auth/test_login_logout.py | 12 +- .../console/auth/test_oauth_server.py | 9 +- .../console/auth/test_token_refresh.py | 4 +- .../rag_pipeline/test_rag_pipeline.py | 31 ++--- .../test_rag_pipeline_workflow.py | 13 +- .../console/test_human_input_form.py | 25 ++-- .../controllers/console/test_remote_files.py | 27 ++-- .../openapi/test_human_input_form.py | 13 +- .../service_api/app/test_annotation.py | 28 ++--- .../controllers/service_api/app/test_audio.py | 17 +-- .../service_api/app/test_completion.py | 23 ++-- .../service_api/app/test_conversation.py | 31 ++--- .../controllers/service_api/app/test_file.py | 14 +-- .../service_api/app/test_hitl_service_api.py | 10 +- .../service_api/app/test_human_input_form.py | 16 +-- .../service_api/app/test_message.py | 27 ++-- .../service_api/app/test_workflow.py | 43 ++++--- .../service_api/app/test_workflow_events.py | 24 ++-- .../controllers/service_api/conftest.py | 8 -- .../service_api/dataset/test_metadata.py | 12 +- .../apps/workflow/test_app_config_manager.py | 2 +- .../services/test_workflow_service.py | 62 ++++----- 45 files changed, 576 insertions(+), 757 deletions(-) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 53744747d7..cb26963dbf 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -21,7 +21,12 @@ from controllers.common.schema import ( from controllers.console import console_ns from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model -from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + edit_permission_required, + setup_required, + with_current_user, +) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager @@ -54,7 +59,7 @@ from libs import helper from libs.datetime_utils import naive_utc_now from libs.helper import TimestampField, dump_response, to_timestamp, uuid_value from libs.login import current_account_with_tenant, login_required -from models import App +from models import Account, App from models.model import AppMode from models.workflow import Workflow from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX @@ -401,13 +406,12 @@ class DraftWorkflowApi(Resource): ) @console_ns.response(400, "Invalid workflow configuration") @console_ns.response(403, "Permission denied") + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Sync draft workflow """ - current_user, _ = current_account_with_tenant() - content_type = request.headers.get("Content-Type", "") if "application/json" in content_type: @@ -468,13 +472,12 @@ class AdvancedChatDraftWorkflowRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Run draft workflow """ - current_user, _ = current_account_with_tenant() - args_model = AdvancedChatWorkflowRunPayload.model_validate(console_ns.payload or {}) args = args_model.model_dump(exclude_none=True) @@ -514,12 +517,12 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Run draft workflow iteration node """ - current_user, _ = current_account_with_tenant() args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: @@ -552,12 +555,12 @@ class WorkflowDraftRunIterationNodeApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Run draft workflow iteration node """ - current_user, _ = current_account_with_tenant() args = IterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) try: @@ -590,12 +593,12 @@ class AdvancedChatDraftRunLoopNodeApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Run draft workflow loop node """ - current_user, _ = current_account_with_tenant() args = LoopNodeRunPayload.model_validate(console_ns.payload or {}) try: @@ -628,12 +631,12 @@ class WorkflowDraftRunLoopNodeApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Run draft workflow loop node """ - current_user, _ = current_account_with_tenant() args = LoopNodeRunPayload.model_validate(console_ns.payload or {}) try: @@ -695,12 +698,12 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Preview human input form content and placeholders """ - current_user, _ = current_account_with_tenant() args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {}) inputs = args.inputs @@ -724,12 +727,12 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Submit human input form preview """ - current_user, _ = current_account_with_tenant() args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {}) workflow_service = WorkflowService() result = workflow_service.submit_human_input_form_preview( @@ -753,12 +756,12 @@ class WorkflowDraftHumanInputFormPreviewApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Preview human input form content and placeholders """ - current_user, _ = current_account_with_tenant() args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {}) inputs = args.inputs @@ -782,12 +785,12 @@ class WorkflowDraftHumanInputFormRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Submit human input form preview """ - current_user, _ = current_account_with_tenant() workflow_service = WorkflowService() args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {}) result = workflow_service.submit_human_input_form_preview( @@ -811,12 +814,12 @@ class WorkflowDraftHumanInputDeliveryTestApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Test human input delivery """ - current_user, _ = current_account_with_tenant() workflow_service = WorkflowService() args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {}) workflow_service.test_human_input_delivery( @@ -841,12 +844,12 @@ class DraftWorkflowRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Run draft workflow """ - current_user, _ = current_account_with_tenant() args = DraftWorkflowRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True) external_trace_id = get_external_trace_id(request) @@ -911,12 +914,12 @@ class DraftWorkflowNodeRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Run draft workflow node """ - current_user, _ = current_account_with_tenant() args_model = DraftWorkflowNodeRunPayload.model_validate(console_ns.payload or {}) args = args_model.model_dump(exclude_none=True) @@ -981,12 +984,12 @@ class PublishedWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Publish workflow """ - current_user, _ = current_account_with_tenant() args = PublishWorkflowPayload.model_validate(console_ns.payload or {}) @@ -1083,14 +1086,14 @@ class ConvertToWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.COMPLETION]) + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Convert basic mode of chatbot app to workflow mode Convert expert mode of chatbot app to workflow mode Convert Completion App to Workflow App """ - current_user, _ = current_account_with_tenant() payload = console_ns.payload or {} args = ConvertToWorkflowPayload.model_validate(payload).model_dump(exclude_none=True) @@ -1122,9 +1125,9 @@ class WorkflowFeaturesApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App): - current_user, _ = current_account_with_tenant() + def post(self, current_user: Account, app_model: App): args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {}) features = args.features @@ -1150,12 +1153,12 @@ class PublishedAllWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def get(self, app_model: App): + def get(self, current_user: Account, app_model: App): """ Get published workflows """ - current_user, _ = current_account_with_tenant() args = WorkflowListQuery.model_validate(request.args.to_dict(flat=True)) page = args.page @@ -1199,9 +1202,9 @@ class DraftWorkflowRestoreApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, workflow_id: str): - current_user, _ = current_account_with_tenant() + def post(self, current_user: Account, app_model: App, workflow_id: str): workflow_service = WorkflowService() try: @@ -1237,12 +1240,12 @@ class WorkflowByIdApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def patch(self, app_model: App, workflow_id: str): + def patch(self, current_user: Account, app_model: App, workflow_id: str): """ Update workflow attributes """ - current_user, _ = current_account_with_tenant() args = WorkflowUpdatePayload.model_validate(console_ns.payload or {}) # Prepare update data @@ -1355,12 +1358,12 @@ class DraftWorkflowTriggerRunApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Poll for trigger events and execute full workflow when event arrives """ - current_user, _ = current_account_with_tenant() args = DraftWorkflowTriggerRunPayload.model_validate(console_ns.payload or {}) node_id = args.node_id workflow_service = WorkflowService() @@ -1419,12 +1422,12 @@ class DraftWorkflowTriggerNodeApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App, node_id: str): + def post(self, current_user: Account, app_model: App, node_id: str): """ Poll for trigger events and execute single node when event arrives """ - current_user, _ = current_account_with_tenant() workflow_service = WorkflowService() draft_workflow = workflow_service.get_draft_workflow(app_model) @@ -1499,12 +1502,12 @@ class DraftWorkflowTriggerRunAllApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.WORKFLOW]) + @with_current_user @edit_permission_required - def post(self, app_model: App): + def post(self, current_user: Account, app_model: App): """ Full workflow debug when the start node is a trigger """ - current_user, _ = current_account_with_tenant() args = DraftWorkflowTriggerRunAllPayload.model_validate(console_ns.payload or {}) node_ids = args.node_ids diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py index 1baac42368..be13f993a1 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_apis.py @@ -3,6 +3,7 @@ from __future__ import annotations import uuid +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -68,15 +69,6 @@ from tests.test_containers_integration_tests.controllers.console.helpers import ) -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def _make_account() -> Account: account = Account( name="tester", @@ -108,7 +100,7 @@ class TestCompletionEndpoints: def test_completion_api_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( completion_module.AppGenerateService, @@ -125,13 +117,13 @@ class TestCompletionEndpoints: "/", json={"inputs": {}, "model_config": {}, "query": "hi"}, ): - resp = method(_make_account(), app_model=MagicMock(id="app-1")) + resp = method(api, _make_account(), app_model=MagicMock(id="app-1")) assert resp == {"result": {"text": "ok"}} def test_completion_api_conversation_not_exists(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( completion_module.AppGenerateService, @@ -146,11 +138,11 @@ class TestCompletionEndpoints: json={"inputs": {}, "model_config": {}, "query": "hi"}, ): with pytest.raises(NotFound): - method(_make_account(), app_model=MagicMock(id="app-1")) + method(api, _make_account(), app_model=MagicMock(id="app-1")) def test_completion_api_provider_not_initialized(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( completion_module.AppGenerateService, @@ -163,11 +155,11 @@ class TestCompletionEndpoints: json={"inputs": {}, "model_config": {}, "query": "hi"}, ): with pytest.raises(completion_module.ProviderNotInitializeError): - method(_make_account(), app_model=MagicMock(id="app-1")) + method(api, _make_account(), app_model=MagicMock(id="app-1")) def test_completion_api_quota_exceeded(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = completion_module.CompletionMessageApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( completion_module.AppGenerateService, @@ -180,7 +172,7 @@ class TestCompletionEndpoints: json={"inputs": {}, "model_config": {}, "query": "hi"}, ): with pytest.raises(completion_module.ProviderQuotaExceededError): - method(_make_account(), app_model=MagicMock(id="app-1")) + method(api, _make_account(), app_model=MagicMock(id="app-1")) class TestAppEndpoints: @@ -190,7 +182,7 @@ class TestAppEndpoints: def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppApi() - method = _unwrap(api.put) + method = unwrap(api.put) payload = { "name": "Updated App", "description": "Updated description", @@ -209,7 +201,7 @@ class TestAppEndpoints: app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), patch.object(type(console_ns), "payload", payload), ): - response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + response = method(api, app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) assert response == {"id": "app-1"} assert app_service.update_app.call_args.args[1]["icon_type"] is None @@ -228,7 +220,7 @@ class TestAppEndpoints: def test_app_icon_post_should_forward_icon_type(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = app_module.AppIconApi() - method = _unwrap(api.post) + method = unwrap(api.post) payload = { "icon": "https://example.com/icon.png", "icon_type": "image", @@ -246,7 +238,7 @@ class TestAppEndpoints: app.test_request_context("/console/api/apps/app-1/icon", method="POST", json=payload), patch.object(type(console_ns), "payload", payload), ): - response = method(app_model=SimpleNamespace()) + response = method(api, app_model=SimpleNamespace()) assert response == {"id": "app-1"} assert app_service.update_app_icon.call_args.args[1:] == ( @@ -300,7 +292,7 @@ class TestOpsTraceEndpoints: def test_trace_app_config_get_empty(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr( ops_trace_module.OpsService, @@ -309,13 +301,13 @@ class TestOpsTraceEndpoints: ) with app.test_request_context("/?tracing_provider=langfuse"): - result = method(app_model=MagicMock(id="app-1")) + result = method(api, app_model=MagicMock(id="app-1")) assert result == {"has_not_configured": True} def test_trace_app_config_post_invalid(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( ops_trace_module.OpsService, @@ -328,11 +320,11 @@ class TestOpsTraceEndpoints: json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}}, ): with pytest.raises(BadRequest): - method(app_model=MagicMock(id="app-1")) + method(api, app_model=MagicMock(id="app-1")) def test_trace_app_config_delete_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = ops_trace_module.TraceAppConfigApi() - method = _unwrap(api.delete) + method = unwrap(api.delete) monkeypatch.setattr( ops_trace_module.OpsService, @@ -342,7 +334,7 @@ class TestOpsTraceEndpoints: with app.test_request_context("/?tracing_provider=langfuse"): with pytest.raises(BadRequest): - method(app_model=MagicMock(id="app-1")) + method(api, app_model=MagicMock(id="app-1")) class TestSiteEndpoints: @@ -360,7 +352,7 @@ class TestSiteEndpoints: def test_app_site_update_post(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSite() - method = _unwrap(api.post) + method = unwrap(api.post) site = MagicMock() site.app_id = "app-1" @@ -386,14 +378,14 @@ class TestSiteEndpoints: monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") with app.test_request_context("/", json={"title": "My Site"}): - result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1")) + result = method(api, SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1")) assert isinstance(result, dict) assert result["title"] == "My Site" def test_app_site_access_token_reset(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = site_module.AppSiteAccessTokenReset() - method = _unwrap(api.post) + method = unwrap(api.post) site = MagicMock() site.app_id = "app-1" @@ -420,7 +412,7 @@ class TestSiteEndpoints: monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") with app.test_request_context("/"): - result = method(SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1")) + result = method(api, SimpleNamespace(id="u1"), app_model=SimpleNamespace(id="app-1")) assert isinstance(result, dict) assert result["access_token"] == "code" @@ -451,7 +443,7 @@ class TestWorkflowAppLogEndpoints: def test_workflow_app_log_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_app_log_module.WorkflowAppLogApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock())) @@ -481,7 +473,7 @@ class TestWorkflowAppLogEndpoints: ) with app.test_request_context("/?page=1&limit=20"): - result = method(app_model=SimpleNamespace(id="app-1")) + result = method(api, app_model=SimpleNamespace(id="app-1")) assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} @@ -497,7 +489,7 @@ class TestWorkflowDraftVariableEndpoints: def test_workflow_variable_collection_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_draft_variable_module.WorkflowVariableCollectionApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock())) @@ -532,7 +524,7 @@ class TestWorkflowDraftVariableEndpoints: monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService) with app.test_request_context("/?page=1&limit=20"): - result = method(_make_account(), app_model=SimpleNamespace(id="app-1")) + result = method(api, _make_account(), app_model=SimpleNamespace(id="app-1")) assert result == {"items": [], "total": 0} @@ -565,10 +557,12 @@ class TestWorkflowStatisticEndpoints: ) api = workflow_statistic_module.WorkflowDailyRunsStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) with app.test_request_context("/"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + response = method( + api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1") + ) assert response.get_json() == {"data": [{"date": "2024-01-01"}]} @@ -588,10 +582,12 @@ class TestWorkflowStatisticEndpoints: ) api = workflow_statistic_module.WorkflowDailyTerminalsStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) with app.test_request_context("/"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + response = method( + api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(tenant_id="t1", id="app-1") + ) assert response.get_json() == {"data": [{"date": "2024-01-02"}]} @@ -610,7 +606,7 @@ class TestWorkflowTriggerEndpoints: def test_webhook_trigger_api_get(self, app: Flask, monkeypatch: pytest.MonkeyPatch): api = workflow_trigger_module.WebhookTriggerApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock())) @@ -635,7 +631,7 @@ class TestWorkflowTriggerEndpoints: monkeypatch.setattr(workflow_trigger_module, "sessionmaker", DummySessionMaker) with app.test_request_context("/?node_id=node-1"): - result = method(app_model=SimpleNamespace(id="app-1")) + result = method(api, app_model=SimpleNamespace(id="app-1")) assert isinstance(result, dict) assert {"id", "webhook_id", "webhook_url", "webhook_debug_url", "node_id", "created_at"} <= set(result.keys()) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index 520ee67ee0..6ac5e9e93f 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -2,6 +2,7 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock @@ -12,15 +13,6 @@ from controllers.console.app import app_import as app_import_module from services.app_dsl_service import ImportStatus -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - class _Result: def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): self.status = status @@ -42,7 +34,7 @@ class TestAppImportApi: def test_import_post_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=False) monkeypatch.setattr( @@ -52,14 +44,14 @@ class TestAppImportApi: ) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) assert status == 400 assert response["status"] == ImportStatus.FAILED def test_import_post_returns_pending_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=False) monkeypatch.setattr( @@ -69,14 +61,14 @@ class TestAppImportApi: ) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) assert status == 202 assert response["status"] == ImportStatus.PENDING def test_import_post_updates_webapp_auth_when_enabled(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=True) monkeypatch.setattr( @@ -88,7 +80,7 @@ class TestAppImportApi: monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) update_access.assert_called_once_with("app-123", "private") assert status == 200 @@ -96,7 +88,7 @@ class TestAppImportApi: def test_import_post_commits_session_on_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=False) monkeypatch.setattr( @@ -111,7 +103,7 @@ class TestAppImportApi: monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) fake_session.commit.assert_called_once_with() fake_session.rollback.assert_not_called() @@ -120,7 +112,7 @@ class TestAppImportApi: def test_import_post_rolls_back_session_on_failure(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportApi() - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=False) monkeypatch.setattr( @@ -135,7 +127,7 @@ class TestAppImportApi: monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) fake_session.rollback.assert_called_once_with() fake_session.commit.assert_not_called() @@ -150,7 +142,7 @@ class TestAppImportConfirmApi: def test_import_confirm_returns_failed_status(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportConfirmApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( app_import_module.AppDslService, @@ -159,7 +151,7 @@ class TestAppImportConfirmApi: ) with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): - response, status = method(SimpleNamespace(id="u1"), import_id="import-1") + response, status = method(api, SimpleNamespace(id="u1"), import_id="import-1") assert status == 400 assert response["status"] == ImportStatus.FAILED @@ -172,7 +164,7 @@ class TestAppImportCheckDependenciesApi: def test_import_check_dependencies_returns_result(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = app_import_module.AppImportCheckDependenciesApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr( app_import_module.AppDslService, @@ -181,7 +173,7 @@ class TestAppImportCheckDependenciesApi: ) with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"): - response, status = method(app_model=SimpleNamespace(id="app-1")) + response, status = method(api, app_model=SimpleNamespace(id="app-1")) assert status == 200 assert response["leaked_dependencies"] == [] diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 6d35655817..91b0055e06 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -239,13 +239,7 @@ class TestTagUnbindingPayload: # Helpers # --------------------------------------------------------------------------- - -def _unwrap(method): - """Walk ``__wrapped__`` chain to get the original function.""" - fn = method - while hasattr(fn, "__wrapped__"): - fn = fn.__wrapped__ - return fn +from inspect import unwrap @pytest.fixture @@ -499,7 +493,7 @@ class TestDatasetListApiPost: json={"name": "New Dataset"}, ): api = DatasetListApi() - response, status = _unwrap(api.post)(api, tenant_id=mock_tenant.id) + response, status = unwrap(api.post)(api, tenant_id=mock_tenant.id) assert status == 200 assert_dataset_detail_shape(response) @@ -527,7 +521,7 @@ class TestDatasetListApiPost: ): api = DatasetListApi() with pytest.raises(DatasetNameDuplicateError): - _unwrap(api.post)(api, tenant_id=mock_tenant.id) + unwrap(api.post)(api, tenant_id=mock_tenant.id) # --------------------------------------------------------------------------- @@ -720,7 +714,7 @@ class TestDatasetApiPatch: json=payload, ): api = DatasetApi() - response, status = _unwrap(api.patch)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) + response, status = unwrap(api.patch)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) assert status == 200 assert_dataset_detail_shape(response, with_partial_members=True) @@ -760,7 +754,7 @@ class TestDatasetApiDelete: method="DELETE", ): api = DatasetApi() - result = _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) + result = unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) assert result == ("", 204) @@ -783,7 +777,7 @@ class TestDatasetApiDelete: ): api = DatasetApi() with pytest.raises(NotFound): - _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) + unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") @@ -804,7 +798,7 @@ class TestDatasetApiDelete: ): api = DatasetApi() with pytest.raises(DatasetInUseError): - _unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) + unwrap(api.delete)(api, _=mock_dataset.tenant_id, dataset_id=mock_dataset.id) # --------------------------------------------------------------------------- diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py index 4e884626a7..c1b20cd02b 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py @@ -19,11 +19,7 @@ def app(flask_app_with_containers) -> Flask: return flask_app_with_containers -def _unwrap(method): - fn = method - while hasattr(fn, "__wrapped__"): - fn = fn.__wrapped__ - return fn +from inspect import unwrap def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: @@ -76,7 +72,7 @@ class TestAppSiteApi: with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): api = AppSiteApi() - response = _unwrap(api.get)(api, app_model=app_model) + response = unwrap(api.get)(api, app_model=app_model) assert response["title"] == "Service API Site" assert response["icon"] == "robot" @@ -89,7 +85,7 @@ class TestAppSiteApi: with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): api = AppSiteApi() with pytest.raises(Forbidden): - _unwrap(api.get)(api, app_model=app_model) + unwrap(api.get)(api, app_model=app_model) def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None: tenant = _create_tenant(db_session_with_containers) @@ -107,4 +103,4 @@ class TestAppSiteApi: with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): api = AppSiteApi() with pytest.raises(Forbidden): - _unwrap(api.get)(api, app_model=app_model) + unwrap(api.get)(api, app_model=app_model) diff --git a/api/tests/test_containers_integration_tests/controllers/web/test_human_input_form.py b/api/tests/test_containers_integration_tests/controllers/web/test_human_input_form.py index c93b3dcb48..b1a8bb3394 100644 --- a/api/tests/test_containers_integration_tests/controllers/web/test_human_input_form.py +++ b/api/tests/test_containers_integration_tests/controllers/web/test_human_input_form.py @@ -62,7 +62,7 @@ def _create_app_with_site(session: Session) -> tuple[App, Account]: tenant_id=tenant.id, name="Test App", description="", - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, icon_type="emoji", icon="app", icon_background="#ffffff", diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 103fe88df7..7efb78fb04 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -205,7 +205,7 @@ class TestHumanInputResumeNodeExecutionIntegration: tenant_id=tenant.id, name="Test App", description="", - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, icon_type=IconType.EMOJI.value, icon="rocket", icon_background="#4ECDC4", diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index de6117945e..85378bd84d 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -75,7 +75,7 @@ def _app_stub(**overrides: Any) -> App: defaults = { "id": str(uuid4()), "tenant_id": _DEFAULT_TENANT_ID, - "mode": AppMode.WORKFLOW.value, + "mode": AppMode.WORKFLOW, "name": "n", "description": "d", "icon_type": IconType.EMOJI, @@ -528,7 +528,7 @@ class TestAppDslService: created_app = SimpleNamespace( id=str(uuid4()), - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, tenant_id=_DEFAULT_TENANT_ID, ) monkeypatch.setattr( @@ -707,7 +707,7 @@ class TestAppDslService: ) app = _app_stub( - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, name="old", description="old-desc", icon_type=IconType.EMOJI, @@ -721,7 +721,7 @@ class TestAppDslService: app=app, data={ "app": { - "mode": AppMode.WORKFLOW.value, + "mode": AppMode.WORKFLOW, "name": "yaml-name", "icon_type": IconType.IMAGE, "icon": "X", @@ -747,7 +747,7 @@ class TestAppDslService: with pytest.raises(ValueError, match="Current tenant is not set"): service._create_or_update_app( app=None, - data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + data={"app": {"mode": AppMode.WORKFLOW, "name": "n"}}, account=account, ) @@ -772,7 +772,7 @@ class TestAppDslService: ) ] data = { - "app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "app": {"mode": AppMode.WORKFLOW, "name": "n"}, "workflow": { "graph": {"nodes": []}, "features": {}, @@ -792,8 +792,8 @@ class TestAppDslService: service = AppDslService(db_session_with_containers) with pytest.raises(ValueError, match="Missing workflow data"): service._create_or_update_app( - app=_app_stub(mode=AppMode.WORKFLOW.value), - data={"app": {"mode": AppMode.WORKFLOW.value}}, + app=_app_stub(mode=AppMode.WORKFLOW), + data={"app": {"mode": AppMode.WORKFLOW}}, account=_account_mock(), ) @@ -852,7 +852,7 @@ class TestAppDslService: ) workflow_app = _app_stub( - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, icon_type="emoji", ) AppDslService.export_dsl(workflow_app) @@ -874,7 +874,7 @@ class TestAppDslService: ) emoji_app = _app_stub( - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, name="Emoji App", icon="🎨", icon_type=IconType.EMOJI, @@ -889,7 +889,7 @@ class TestAppDslService: assert data["app"]["icon_background"] == "#FF5733" image_app = _app_stub( - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, name="Image App", icon="https://example.com/icon.png", icon_type=IconType.IMAGE, diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 3e5905efbb..9241d5bcc9 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -50,7 +50,7 @@ def _create_app_with_draft_workflow( tenant_id=tenant.id, name="Test App", description="", - mode=AppMode.WORKFLOW.value, + mode=AppMode.WORKFLOW, icon_type="emoji", icon="app", icon_background="#ffffff", diff --git a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py index a1f567de74..bb1c25d2e8 100644 --- a/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py +++ b/api/tests/unit_tests/controllers/console/agent/test_agent_controllers.py @@ -1,3 +1,4 @@ +from inspect import unwrap from types import SimpleNamespace from typing import Protocol, cast @@ -26,12 +27,6 @@ from controllers.console.agent.roster import ( from services.entities.agent_entities import ComposerSaveStrategy, ComposerVariant -def _unwrap(method): - while hasattr(method, "__wrapped__"): - method = method.__wrapped__ - return method - - def _agent_response(agent_id: str = "agent-1") -> dict: return { "id": agent_id, @@ -135,7 +130,7 @@ def test_roster_list_get_parses_query_and_calls_service(app: Flask, monkeypatch: monkeypatch.setattr(roster_controller.AgentRosterService, "list_roster_agents", list_roster_agents) with app.test_request_context("/console/api/agents?page=2&limit=5&keyword=analyst"): - result = _unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1") + result = unwrap(AgentRosterListApi.get)(AgentRosterListApi(), "tenant-1") assert result["page"] == 2 assert captured == {"tenant_id": "tenant-1", "page": 2, "limit": 5, "keyword": "analyst"} @@ -157,7 +152,7 @@ def test_roster_list_post_creates_agent_and_returns_detail( ) with app.test_request_context(json={"name": "Analyst", "agent_soul": {"prompt": {"system_prompt": "x"}}}): - result, status = _unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id) + result, status = unwrap(AgentRosterListApi.post)(AgentRosterListApi(), "tenant-1", account_id) assert status == 201 assert result["id"] == "agent-1" @@ -174,7 +169,7 @@ def test_invite_options_get_parses_app_id(app: Flask, monkeypatch: pytest.Monkey monkeypatch.setattr(roster_controller.AgentRosterService, "list_invite_options", list_invite_options) with app.test_request_context("/console/api/agents/invite-options?page=1&limit=10&app_id=app-1"): - result = _unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1") + result = unwrap(AgentInviteOptionsApi.get)(AgentInviteOptionsApi(), "tenant-1") assert result == {"data": [], "page": 1, "limit": 10, "total": 0, "has_more": False} assert captured == {"tenant_id": "tenant-1", "page": 1, "limit": 10, "keyword": None, "app_id": "app-1"} @@ -232,19 +227,19 @@ def test_roster_detail_patch_delete_and_versions_call_services( }, ) - assert _unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id + assert unwrap(AgentRosterDetailApi.get)(AgentRosterDetailApi(), "tenant-1", agent_id)["id"] == agent_id with app.test_request_context(json={"description": "updated"}): assert ( - _unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"] + unwrap(AgentRosterDetailApi.patch)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id)["description"] == "updated" ) - assert _unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204) + assert unwrap(AgentRosterDetailApi.delete)(AgentRosterDetailApi(), "tenant-1", account_id, agent_id) == ("", 204) assert archived["account_id"] == "account-1" assert ( - _unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"] + unwrap(AgentRosterVersionsApi.get)(AgentRosterVersionsApi(), "tenant-1", agent_id)["data"][0]["id"] == "version-1" ) - version_detail = _unwrap(AgentRosterVersionDetailApi.get)( + version_detail = unwrap(AgentRosterVersionDetailApi.get)( AgentRosterVersionDetailApi(), "tenant-1", agent_id, version_id ) assert version_detail["id"] == version_id @@ -286,27 +281,27 @@ def test_workflow_composer_get_put_validate_candidates_impact_and_save( }, ) - workflow_state = _unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1") + workflow_state = unwrap(WorkflowAgentComposerApi.get)(WorkflowAgentComposerApi(), "tenant-1", app_model, "node-1") assert workflow_state["node_id"] == "node-1" with app.test_request_context(json=payload): - saved_state = _unwrap(WorkflowAgentComposerApi.put)( + saved_state = unwrap(WorkflowAgentComposerApi.put)( WorkflowAgentComposerApi(), "tenant-1", account_id, app_model, "node-1" ) assert saved_state["save_options"] == ["node_job_only"] - assert _unwrap(WorkflowAgentComposerValidateApi.post)( + assert unwrap(WorkflowAgentComposerValidateApi.post)( WorkflowAgentComposerValidateApi(), app_model, "node-1" ) == {"result": "success", "errors": []} assert ( - _unwrap(WorkflowAgentComposerCandidatesApi.get)(WorkflowAgentComposerCandidatesApi(), app_model, "node-1")[ + unwrap(WorkflowAgentComposerCandidatesApi.get)(WorkflowAgentComposerCandidatesApi(), app_model, "node-1")[ "variant" ] == "workflow" ) with app.test_request_context(json=payload): - assert _unwrap(WorkflowAgentComposerImpactApi.post)( + assert unwrap(WorkflowAgentComposerImpactApi.post)( WorkflowAgentComposerImpactApi(), "tenant-1", app_model, "node-1" ) == {"current_snapshot_id": "version-1", "workflow_node_count": 1, "bindings": []} - assert _unwrap(WorkflowAgentComposerSaveToRosterApi.post)( + assert unwrap(WorkflowAgentComposerSaveToRosterApi.post)( WorkflowAgentComposerSaveToRosterApi(), "tenant-1", account_id, app_model, "node-1" )["save_options"] == ["node_job_only"] @@ -315,7 +310,7 @@ def test_workflow_impact_returns_empty_without_version(app: Flask) -> None: payload = {"variant": ComposerVariant.WORKFLOW.value, "save_strategy": ComposerSaveStrategy.NODE_JOB_ONLY.value} with app.test_request_context(json=payload): - result = _unwrap(WorkflowAgentComposerImpactApi.post)( + result = unwrap(WorkflowAgentComposerImpactApi.post)( WorkflowAgentComposerImpactApi(), "tenant-1", SimpleNamespace(id="app-1"), "node-1" ) @@ -348,15 +343,15 @@ def test_agent_app_composer_get_put_validate_and_candidates( lambda **kwargs: _candidates_response("agent_app"), ) - assert _unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app" + assert unwrap(AgentAppComposerApi.get)(AgentAppComposerApi(), "tenant-1", app_model)["variant"] == "agent_app" with app.test_request_context(json=payload): assert ( - _unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"] + unwrap(AgentAppComposerApi.put)(AgentAppComposerApi(), "tenant-1", account_id, app_model)["variant"] == "agent_app" ) - assert _unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == { + assert unwrap(AgentAppComposerValidateApi.post)(AgentAppComposerValidateApi(), app_model) == { "result": "success", "errors": [], } - agent_app_candidates = _unwrap(AgentAppComposerCandidatesApi.get)(AgentAppComposerCandidatesApi(), app_model) + agent_app_candidates = unwrap(AgentAppComposerCandidatesApi.get)(AgentAppComposerCandidatesApi(), app_model) assert agent_app_candidates["variant"] == "agent_app" diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py index 386f75e231..0cccb34b08 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -2,6 +2,7 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock @@ -12,15 +13,6 @@ from controllers.console.app import app_import as app_import_module from services.app_dsl_service import ImportStatus -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - class _Result: def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): self.status = status @@ -52,7 +44,7 @@ class TestAppImportApi: def test_import_post_returns_failed_status_and_rolls_back( self, api, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=False) session = _mock_session(monkeypatch) @@ -63,7 +55,7 @@ class TestAppImportApi: ) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) session.rollback.assert_called_once_with() session.commit.assert_not_called() @@ -73,7 +65,7 @@ class TestAppImportApi: def test_import_post_returns_pending_status_and_commits( self, api, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=False) session = _mock_session(monkeypatch) @@ -84,7 +76,7 @@ class TestAppImportApi: ) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) session.commit.assert_called_once_with() session.rollback.assert_not_called() @@ -94,7 +86,7 @@ class TestAppImportApi: def test_import_post_updates_webapp_auth_when_enabled( self, api, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - method = _unwrap(api.post) + method = unwrap(api.post) _install_features(monkeypatch, enabled=True) session = _mock_session(monkeypatch) @@ -107,7 +99,7 @@ class TestAppImportApi: monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): - response, status = method(SimpleNamespace(id="u1")) + response, status = method(api, SimpleNamespace(id="u1")) session.commit.assert_called_once_with() session.rollback.assert_not_called() @@ -124,7 +116,7 @@ class TestAppImportConfirmApi: def test_import_confirm_returns_failed_status_and_rolls_back( self, api, app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - method = _unwrap(api.post) + method = unwrap(api.post) session = _mock_session(monkeypatch) monkeypatch.setattr( @@ -134,7 +126,7 @@ class TestAppImportConfirmApi: ) with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): - response, status = method(SimpleNamespace(id="u1"), import_id="import-1") + response, status = method(api, SimpleNamespace(id="u1"), import_id="import-1") session.rollback.assert_called_once_with() session.commit.assert_not_called() diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index 2d218dac7e..82b9b68247 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -1,9 +1,11 @@ from __future__ import annotations import io +from inspect import unwrap from types import SimpleNamespace import pytest +from flask import Flask from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -32,27 +34,18 @@ from services.errors.audio import ( ) -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def _file_data(): return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav") -def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_console_audio_api_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) api = ChatMessageAudioApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): - response = handler(app_model=app_model) + response = handler(api, app_model=app_model) assert response == {"text": "ok"} @@ -71,33 +64,33 @@ def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None (InvokeError("invoke"), CompletionRequestError), ], ) -def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: +def test_console_audio_api_error_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) api = ChatMessageAudioApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): with pytest.raises(expected): - handler(app_model=app_model) + handler(api, app_model=app_model) -def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_console_audio_api_unhandled_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) api = ChatMessageAudioApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): with pytest.raises(InternalServerError): - handler(app_model=app_model) + handler(api, app_model=app_model) -def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_console_text_api_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) api = ChatMessageTextApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") with app.test_request_context( @@ -105,16 +98,16 @@ def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: method="POST", json={"text": "hello", "voice": "v"}, ): - response = handler(app_model=app_model) + response = handler(api, app_model=app_model) assert response == {"audio": "ok"} -def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_console_text_api_error_mapping(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())) api = ChatMessageTextApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") with app.test_request_context( @@ -123,23 +116,23 @@ def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> json={"text": "hello"}, ): with pytest.raises(ProviderQuotaExceededError): - handler(app_model=app_model) + handler(api, app_model=app_model) -def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_console_text_modes_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) api = TextModesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(tenant_id="t1") with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"): - response = handler(app_model=app_model) + response = handler(api, app_model=app_model) assert response == ["voice-1"] -def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_console_text_modes_language_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( AudioService, "transcript_tts_voices", @@ -147,17 +140,17 @@ def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) ) api = TextModesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(tenant_id="t1") with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"): with pytest.raises(AppUnavailableError): - handler(app_model=app_model) + handler(api, app_model=app_model) -def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_audio_to_text_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = ChatMessageAudioApi() - method = _unwrap(api.post) + method = unwrap(api.post) response_payload = {"text": "hello"} monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload) @@ -171,14 +164,14 @@ def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: data=data, content_type="multipart/form-data", ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == response_payload -def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_audio_to_text_maps_audio_too_large(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = ChatMessageAudioApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( AudioService, @@ -196,12 +189,12 @@ def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch content_type="multipart/form-data", ): with pytest.raises(AudioTooLargeError): - method(app_model=app_model) + method(api, app_model=app_model) -def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = ChatMessageTextApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) @@ -212,14 +205,14 @@ def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: method="POST", json={"text": "hello"}, ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == {"audio": "ok"} -def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_voices_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = TextModesApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) @@ -230,14 +223,14 @@ def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> N method="GET", query_string={"language": "en-US"}, ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == ["voice-1"] -def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_audio_to_text_with_invalid_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = ChatMessageAudioApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"}) @@ -251,13 +244,13 @@ def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) - content_type="multipart/form-data", ): # Should not raise, AudioService is mocked - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == {"text": "test"} -def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_with_language_param(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = ChatMessageTextApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"}) @@ -268,13 +261,13 @@ def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) method="POST", json={"text": "hello", "language": "en-US"}, ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == {"audio": "test"} -def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_voices_with_language_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = TextModesApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr( AudioService, @@ -288,5 +281,5 @@ def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.Monk "/console/api/apps/app-1/text-to-audio/voices?language=en-US", method="GET", ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert isinstance(response, list) diff --git a/api/tests/unit_tests/controllers/console/app/test_audio_api.py b/api/tests/unit_tests/controllers/console/app/test_audio_api.py index 8b71837c29..40e6be1141 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio_api.py @@ -1,27 +1,20 @@ from __future__ import annotations import io +from inspect import unwrap from types import SimpleNamespace import pytest +from flask import Flask from controllers.console.app import audio as audio_module from controllers.console.app.error import AudioTooLargeError from services.errors.audio import AudioTooLargeServiceError -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - -def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_audio_to_text_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.ChatMessageAudioApi() - method = _unwrap(api.post) + method = unwrap(api.post) response_payload = {"text": "hello"} monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload) @@ -35,14 +28,14 @@ def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: data=data, content_type="multipart/form-data", ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == response_payload -def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_audio_to_text_maps_audio_too_large(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.ChatMessageAudioApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr( audio_module.AudioService, @@ -60,12 +53,12 @@ def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch content_type="multipart/form-data", ): with pytest.raises(AudioTooLargeError): - method(app_model=app_model) + method(api, app_model=app_model) -def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.ChatMessageTextApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) @@ -76,14 +69,14 @@ def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: method="POST", json={"text": "hello"}, ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == {"audio": "ok"} -def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_voices_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.TextModesApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) @@ -94,14 +87,14 @@ def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> N method="GET", query_string={"language": "en-US"}, ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == ["voice-1"] -def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_audio_to_text_with_invalid_file(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.ChatMessageAudioApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"}) @@ -115,13 +108,13 @@ def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) - content_type="multipart/form-data", ): # Should not raise, AudioService is mocked - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == {"text": "test"} -def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_with_language_param(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.ChatMessageTextApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"}) @@ -132,13 +125,13 @@ def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) method="POST", json={"text": "hello", "language": "en-US"}, ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert response == {"audio": "test"} -def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_text_to_audio_voices_with_language_filter(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = audio_module.TextModesApi() - method = _unwrap(api.get) + method = unwrap(api.get) monkeypatch.setattr( audio_module.AudioService, @@ -152,5 +145,5 @@ def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.Monk "/console/api/apps/app-1/text-to-audio/voices?language=en-US", method="GET", ): - response = method(app_model=app_model) + response = method(api, app_model=app_model) assert isinstance(response, list) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py index 41924bbfd3..5de07ff14e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -1,9 +1,11 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from werkzeug.exceptions import BadRequest, NotFound from controllers.console.app import conversation as conversation_module @@ -11,22 +13,13 @@ from models.model import AppMode from services.errors.conversation import ConversationNotExistsError -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def _make_account(): return SimpleNamespace(timezone="UTC", id="u1") -def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_completion_conversation_list_returns_paginated_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = conversation_module.CompletionConversationApi() - method = _unwrap(api.get) + method = unwrap(api.get) account = _make_account() monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) @@ -40,14 +33,14 @@ def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): - response = method(account, app_model=SimpleNamespace(id="app-1")) + response = method(api, account, app_model=SimpleNamespace(id="app-1")) assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} -def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_completion_conversation_list_invalid_time_range(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = conversation_module.CompletionConversationApi() - method = _unwrap(api.get) + method = unwrap(api.get) account = _make_account() monkeypatch.setattr( @@ -62,12 +55,12 @@ def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytes query_string={"start": "bad"}, ): with pytest.raises(BadRequest): - method(account, app_model=SimpleNamespace(id="app-1")) + method(api, account, app_model=SimpleNamespace(id="app-1")) -def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_chat_conversation_list_advanced_chat_calls_paginate(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = conversation_module.ChatConversationApi() - method = _unwrap(api.get) + method = unwrap(api.get) account = _make_account() monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) @@ -81,7 +74,7 @@ def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: p monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): - response = method(account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) + response = method(api, account, app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) assert response == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} @@ -114,7 +107,7 @@ def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPat def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None: api = conversation_module.CompletionConversationDetailApi() - method = _unwrap(api.delete) + method = unwrap(api.delete) monkeypatch.setattr( conversation_module.ConversationService, @@ -123,4 +116,4 @@ def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.Monke ) with pytest.raises(NotFound): - method(_make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1") + method(api, _make_account(), app_model=SimpleNamespace(id="app-1"), conversation_id="c1") diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py index 71b6a1aa37..ab3eacd03c 100644 --- a/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_variables_api.py @@ -2,6 +2,7 @@ from __future__ import annotations from contextlib import nullcontext from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace import pytest @@ -12,18 +13,9 @@ from controllers.console.app import conversation_variables as conversation_varia from graphon.variables.types import SegmentType -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def test_get_conversation_variables_returns_paginated_response(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = conversation_variables_module.ConversationVariablesApi() - method = _unwrap(api.get) + method = unwrap(api.get) created_at = datetime(2026, 1, 1, tzinfo=UTC) updated_at = datetime(2026, 1, 2, tzinfo=UTC) @@ -53,7 +45,7 @@ def test_get_conversation_variables_returns_paginated_response(app: Flask, monke method="GET", query_string={"conversation_id": "conv-1"}, ): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(api, app_model=SimpleNamespace(id="app-1")) assert response["page"] == 1 assert response["limit"] == 100 @@ -68,7 +60,7 @@ def test_get_conversation_variables_normalizes_value_type_and_value( app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: api = conversation_variables_module.ConversationVariablesApi() - method = _unwrap(api.get) + method = unwrap(api.get) row = SimpleNamespace( created_at=None, @@ -96,7 +88,7 @@ def test_get_conversation_variables_normalizes_value_type_and_value( method="GET", query_string={"conversation_id": "conv-1"}, ): - response = method(app_model=SimpleNamespace(id="app-1")) + response = method(api, app_model=SimpleNamespace(id="app-1")) assert response["data"][0]["value_type"] == "number" assert response["data"][0]["value"] == "42" @@ -104,8 +96,8 @@ def test_get_conversation_variables_normalizes_value_type_and_value( def test_get_conversation_variables_requires_conversation_id(app) -> None: api = conversation_variables_module.ConversationVariablesApi() - method = _unwrap(api.get) + method = unwrap(api.get) with app.test_request_context("/console/api/apps/app-1/conversation-variables", method="GET"): with pytest.raises(ValidationError): - method(app_model=SimpleNamespace(id="app-1")) + method(api, app_model=SimpleNamespace(id="app-1")) diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py index 0bf4215244..308089b848 100644 --- a/api/tests/unit_tests/controllers/console/app/test_generator_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -1,24 +1,17 @@ from __future__ import annotations +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from flask import Flask from controllers.console.app import generator as generator_module from controllers.console.app.error import ProviderNotInitializeError from core.errors.error import ProviderTokenNotInitError -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def _model_config_payload(): return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}} @@ -38,9 +31,9 @@ def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow): return service -def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_rule_generate_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.RuleGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []}) @@ -49,14 +42,14 @@ def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: method="POST", json={"instruction": "do it", "model_config": _model_config_payload()}, ): - response = method("t1") + response = method(api, "t1") assert response == {"rules": []} -def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_rule_code_generate_maps_token_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.RuleCodeGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) def _raise(*_args, **_kwargs): raise ProviderTokenNotInitError("missing token") @@ -69,12 +62,12 @@ def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatc json={"instruction": "do it", "model_config": _model_config_payload()}, ): with pytest.raises(ProviderNotInitializeError): - method("t1") + method(api, "t1") -def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_instruction_generate_app_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) session = MagicMock() session.get.return_value = None @@ -89,16 +82,16 @@ def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch "model_config": _model_config_payload(), }, ): - response, status = method(session, "t1") + response, status = method(api, session, "t1") assert status == 400 assert response["error"] == "app app-1 not found" session.get.assert_called_once_with(generator_module.App, "app-1") -def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_instruction_generate_workflow_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) app_model = SimpleNamespace(id="app-1") session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model) @@ -114,15 +107,15 @@ def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.Monkey "model_config": _model_config_payload(), }, ): - response, status = method(session, "t1") + response, status = method(api, session, "t1") assert status == 400 assert response["error"] == "workflow app-1 not found" -def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_instruction_generate_node_missing(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) app_model = SimpleNamespace(id="app-1") session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model) @@ -140,15 +133,15 @@ def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) "model_config": _model_config_payload(), }, ): - response, status = method(session, "t1") + response, status = method(api, session, "t1") assert status == 400 assert response["error"] == "node node-1 not found" -def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_instruction_generate_code_node(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) app_model = SimpleNamespace(id="app-1") session = SimpleNamespace(get=lambda *_args, **_kwargs: app_model) @@ -173,16 +166,16 @@ def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> "model_config": _model_config_payload(), }, ): - response = method(session, "t1") + response = method(api, session, "t1") assert response == {"code": "x"} assert workflow_service.app_model is app_model assert workflow_service.session is session -def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_instruction_generate_legacy_modify(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) session = SimpleNamespace() monkeypatch.setattr( @@ -202,14 +195,14 @@ def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch "model_config": _model_config_payload(), }, ): - response = method(session, "t1") + response = method(api, session, "t1") assert response == {"instruction": "ok"} -def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_instruction_generate_incompatible_params(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.InstructionGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) session = SimpleNamespace() with app.test_request_context( @@ -223,29 +216,29 @@ def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.Monke "model_config": _model_config_payload(), }, ): - response, status = method(session, "t1") + response, status = method(api, session, "t1") assert status == 400 assert response["error"] == "incompatible parameters" -def test_instruction_template_prompt(app) -> None: +def test_instruction_template_prompt(app: Flask) -> None: api = generator_module.InstructionGenerationTemplateApi() - method = _unwrap(api.post) + method = unwrap(api.post) with app.test_request_context( "/console/api/instruction-generate/template", method="POST", json={"type": "prompt"}, ): - response = method() + response = method(api) assert "data" in response -def test_instruction_template_invalid_type(app) -> None: +def test_instruction_template_invalid_type(app: Flask) -> None: api = generator_module.InstructionGenerationTemplateApi() - method = _unwrap(api.post) + method = unwrap(api.post) with app.test_request_context( "/console/api/instruction-generate/template", @@ -253,7 +246,7 @@ def test_instruction_template_invalid_type(app) -> None: json={"type": "unknown"}, ): with pytest.raises(ValueError): - method() + method(api) # ─ /workflow-generate ───────────────────────────────────────────────────────── @@ -281,9 +274,9 @@ def _stub_workflow_service(monkeypatch: pytest.MonkeyPatch, returns=None, raises monkeypatch.setattr(generator_module.WorkflowGeneratorService, "generate_workflow_graph", _call) -def test_workflow_generate_returns_service_result(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_returns_service_result(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) expected = { "graph": {"nodes": [{"id": "node-1"}], "edges": [], "viewport": {"x": 0, "y": 0, "zoom": 0.7}}, @@ -297,16 +290,16 @@ def test_workflow_generate_returns_service_result(app, monkeypatch: pytest.Monke method="POST", json=_workflow_generate_payload(), ): - response = method("t1") + response = method(api, "t1") assert response == expected -def test_workflow_generate_maps_provider_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_maps_provider_token_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """ProviderTokenNotInitError → ProviderNotInitializeError so the frontend can render the same "provider missing" UX as /rule-generate.""" api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) _stub_workflow_service(monkeypatch, raises=ProviderTokenNotInitError("missing token")) @@ -316,15 +309,15 @@ def test_workflow_generate_maps_provider_token_error(app, monkeypatch: pytest.Mo json=_workflow_generate_payload(), ): with pytest.raises(ProviderNotInitializeError): - method("t1") + method(api, "t1") -def test_workflow_generate_maps_quota_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_maps_quota_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: from controllers.console.app.error import ProviderQuotaExceededError from core.errors.error import QuotaExceededError api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) _stub_workflow_service(monkeypatch, raises=QuotaExceededError()) @@ -334,15 +327,15 @@ def test_workflow_generate_maps_quota_error(app, monkeypatch: pytest.MonkeyPatch json=_workflow_generate_payload(), ): with pytest.raises(ProviderQuotaExceededError): - method("t1") + method(api, "t1") -def test_workflow_generate_maps_model_not_support_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_maps_model_not_support_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: from controllers.console.app.error import ProviderModelCurrentlyNotSupportError from core.errors.error import ModelCurrentlyNotSupportError api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) _stub_workflow_service(monkeypatch, raises=ModelCurrentlyNotSupportError("not supported")) @@ -352,15 +345,15 @@ def test_workflow_generate_maps_model_not_support_error(app, monkeypatch: pytest json=_workflow_generate_payload(), ): with pytest.raises(ProviderModelCurrentlyNotSupportError): - method("t1") + method(api, "t1") -def test_workflow_generate_maps_invoke_error(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_maps_invoke_error(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: from controllers.console.app.error import CompletionRequestError from graphon.model_runtime.errors.invoke import InvokeError api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) _stub_workflow_service(monkeypatch, raises=InvokeError("LLM unreachable")) @@ -370,13 +363,13 @@ def test_workflow_generate_maps_invoke_error(app, monkeypatch: pytest.MonkeyPatc json=_workflow_generate_payload(), ): with pytest.raises(CompletionRequestError): - method("t1") + method(api, "t1") -def test_workflow_generate_accepts_advanced_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_accepts_advanced_chat_mode(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """The payload Literal must accept advanced-chat as well as workflow.""" api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) captured: dict = {} @@ -397,17 +390,17 @@ def test_workflow_generate_accepts_advanced_chat_mode(app, monkeypatch: pytest.M method="POST", json=payload, ): - method("t1") + method(api, "t1") assert captured["mode"] == "advanced-chat" assert captured["instruction"] == "Summarize a URL" assert captured["ideal_output"] == "A 3-sentence summary." -def test_workflow_generate_forwards_current_graph_for_refine(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_forwards_current_graph_for_refine(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """cmd+k `/refine`: the optional current_graph field reaches the service.""" api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) captured: dict = {} @@ -429,15 +422,15 @@ def test_workflow_generate_forwards_current_graph_for_refine(app, monkeypatch: p method="POST", json=payload, ): - method("t1") + method(api, "t1") assert captured["current_graph"] == graph -def test_workflow_generate_current_graph_defaults_to_none(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_generate_current_graph_defaults_to_none(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: """Omitting current_graph (the `/create` path) forwards None to the service.""" api = generator_module.WorkflowGenerateApi() - method = _unwrap(api.post) + method = unwrap(api.post) captured: dict = {} @@ -456,6 +449,6 @@ def test_workflow_generate_current_graph_defaults_to_none(app, monkeypatch: pyte method="POST", json=_workflow_generate_payload(), ): - method("t1") + method(api, "t1") assert captured["current_graph"] is None diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py index c984dbef5d..27bc5e341e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_message_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -7,15 +7,6 @@ import pytest from controllers.console.app import message as message_module -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None: """Test valid ChatMessagesQuery with all fields.""" query = message_module.ChatMessagesQuery( diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py index 5fc60d8046..714eae618e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock @@ -11,18 +12,9 @@ from controllers.console.app import model_config as model_config_module from models.model import AppMode, AppModelConfig -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = model_config_module.ModelConfigResource() - method = _unwrap(api.post) + method = unwrap(api.post) app_model = SimpleNamespace( id="app-1", @@ -50,7 +42,7 @@ def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest. monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): - response = method("t1", "u1", app_model=app_model) + response = method(api, "t1", "u1", app_model=app_model) session.add.assert_called_once() session.flush.assert_called_once() @@ -62,7 +54,7 @@ def test_post_updates_app_model_config_for_chat(app: Flask, monkeypatch: pytest. def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = model_config_module.ModelConfigResource() - method = _unwrap(api.post) + method = unwrap(api.post) app_model = SimpleNamespace( id="app-1", @@ -137,7 +129,7 @@ def test_post_encrypts_agent_tool_parameters(app: Flask, monkeypatch: pytest.Mon monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): - response = method("t1", "u1", app_model=app_model) + response = method(api, "t1", "u1", app_model=app_model) stored_config = session.add.call_args[0][0] stored_agent_mode = json.loads(stored_config.agent_mode) diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py index 4093398341..b31dccd034 100644 --- a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py @@ -1,6 +1,7 @@ from __future__ import annotations from decimal import Decimal +from inspect import unwrap from types import SimpleNamespace import pytest @@ -9,15 +10,6 @@ from werkzeug.exceptions import BadRequest from controllers.console.app import statistic as statistic_module -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func - - class _ConnContext: def __init__(self, rows): self._rows = rows @@ -48,42 +40,42 @@ def _install_common(monkeypatch: pytest.MonkeyPatch) -> None: def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [SimpleNamespace(date="2024-01-01", message_count=3)] _install_common(monkeypatch) _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]} def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyConversationStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] _install_common(monkeypatch) _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyTokenCostStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")] _install_common(monkeypatch) _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) data = response.get_json() assert len(data["data"]) == 1 @@ -94,14 +86,14 @@ def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.Monkey def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyTerminalsStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)] _install_common(monkeypatch) _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]} @@ -111,13 +103,13 @@ def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypat # This just verifies the decorator is applied correctly # Actual endpoint testing would require complex JOIN mocking api = statistic_module.AverageSessionInteractionStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) assert callable(method) def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) def mock_parse(*args, **kwargs): raise ValueError("Invalid time range") @@ -128,12 +120,12 @@ def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytes with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): with pytest.raises(BadRequest): - method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [ SimpleNamespace(date="2024-01-01", message_count=10), @@ -144,7 +136,7 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) data = response.get_json() assert len(data["data"]) == 3 @@ -152,20 +144,20 @@ def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPa def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyMessageStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) _install_common(monkeypatch) _install_db(monkeypatch, []) with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": []} def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyConversationStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] _install_db(monkeypatch, rows) @@ -177,14 +169,14 @@ def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.M monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None: api = statistic_module.DailyTokenCostStatistic() - method = _unwrap(api.get) + method = unwrap(api.get) rows = [ SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"), @@ -194,7 +186,7 @@ def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.Monk _install_db(monkeypatch, rows) with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): - response = method(SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) + response = method(api, SimpleNamespace(timezone="UTC"), app_model=SimpleNamespace(id="app-1")) data = response.get_json() assert len(data["data"]) == 2 diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index e7fc1f8042..a03f09e91e 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect import json from datetime import datetime from types import SimpleNamespace @@ -7,6 +8,7 @@ from typing import cast from unittest.mock import Mock import pytest +from flask import Flask from pydantic import ValidationError from werkzeug.exceptions import HTTPException, NotFound @@ -17,12 +19,6 @@ from graphon.variables import SecretVariable, StringVariable from graphon.variables.variables import RAGPipelineVariable -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _make_workflow(**overrides): workflow = SimpleNamespace( id="workflow-1", @@ -107,24 +103,20 @@ def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: build_mock.assert_called_once() -def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_sync_draft_workflow_invalid_content_type(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = workflow_module.DraftWorkflowApi() - handler = _unwrap(api.post) - - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + handler = inspect.unwrap(api.post) with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"): with pytest.raises(HTTPException) as exc: - handler(api, app_model=SimpleNamespace(id="app")) + handler(api, "t1", app_model=SimpleNamespace(id="app")) assert exc.value.code == 415 -def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_sync_draft_workflow_invalid_json(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: api = workflow_module.DraftWorkflowApi() - handler = _unwrap(api.post) - - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/draft", @@ -132,19 +124,19 @@ def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) data="[]", content_type="application/json", ): - response, status = handler(api, app_model=SimpleNamespace(id="app")) + response, status = handler(api, "t1", app_model=SimpleNamespace(id="app")) assert status == 400 assert response["message"] == "Invalid JSON data" -def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_sync_draft_workflow_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow = SimpleNamespace( unique_hash="h", updated_at=None, created_at=datetime(2024, 1, 1), ) - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + monkeypatch.setattr( workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env" ) @@ -156,20 +148,19 @@ def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> No monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service) api = workflow_module.DraftWorkflowApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/draft", method="POST", json={"graph": {}, "features": {}, "hash": "h"}, ): - response = handler(api, app_model=SimpleNamespace(id="app")) + response = handler(api, "t1", app_model=SimpleNamespace(id="app")) assert response["result"] == "success" -def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) +def test_sync_draft_workflow_hash_mismatch(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: def _raise(*_args, **_kwargs): raise workflow_module.WorkflowHashNotEqualError() @@ -178,7 +169,7 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service) api = workflow_module.DraftWorkflowApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/draft", @@ -186,10 +177,10 @@ def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) json={"graph": {}, "features": {}, "hash": "h"}, ): with pytest.raises(DraftWorkflowNotSync): - handler(api, app_model=SimpleNamespace(id="app")) + handler(api, "t1", app_model=SimpleNamespace(id="app")) -def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_restore_published_workflow_to_draft_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: workflow = SimpleNamespace( unique_hash="restored-hash", updated_at=None, @@ -197,7 +188,6 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo ) user = SimpleNamespace(id="account-1") - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) monkeypatch.setattr( workflow_module, "WorkflowService", @@ -205,7 +195,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo ) api = workflow_module.DraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/published-workflow/restore", @@ -213,6 +203,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo ): response = handler( api, + "t1", app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), workflow_id="published-workflow", ) @@ -221,10 +212,7 @@ def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.Mo assert response["hash"] == "restored-hash" -def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: - user = SimpleNamespace(id="account-1") - - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) +def test_restore_published_workflow_to_draft_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", @@ -236,7 +224,7 @@ def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest. ) api = workflow_module.DraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/published-workflow/restore", @@ -245,15 +233,15 @@ def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest. with pytest.raises(NotFound): handler( api, + "t1", app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), workflow_id="published-workflow", ) -def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None: - user = SimpleNamespace(id="account-1") - - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) +def test_restore_published_workflow_to_draft_returns_400_for_draft_source( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", @@ -268,7 +256,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m ) api = workflow_module.DraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/draft-workflow/restore", @@ -277,6 +265,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m with pytest.raises(HTTPException) as exc: handler( api, + "t1", app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), workflow_id="draft-workflow", ) @@ -286,11 +275,8 @@ def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, m def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( - app, monkeypatch: pytest.MonkeyPatch + app: Flask, monkeypatch: pytest.MonkeyPatch ) -> None: - user = SimpleNamespace(id="account-1") - - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) monkeypatch.setattr( workflow_module, "WorkflowService", @@ -302,7 +288,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( ) api = workflow_module.DraftWorkflowRestoreApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/workflows/published-workflow/restore", @@ -311,6 +297,7 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( with pytest.raises(HTTPException) as exc: handler( api, + "t1", app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), workflow_id="published-workflow", ) @@ -319,9 +306,11 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( assert exc.value.description == "invalid workflow graph" -def test_get_published_workflows_serializes_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_get_published_workflows_serializes_items_before_session_closes( + app: Flask, monkeypatch: pytest.MonkeyPatch +) -> None: api = workflow_module.PublishedAllWorkflowApi() - handler = _unwrap(api.get) + handler = inspect.unwrap(api.get) session_state = {"open": False} @@ -351,7 +340,6 @@ def test_get_published_workflows_serializes_items_before_session_closes(app, mon monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker()) - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) monkeypatch.setattr( workflow_module, "WorkflowService", @@ -365,7 +353,7 @@ def test_get_published_workflows_serializes_items_before_session_closes(app, mon method="GET", query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"}, ): - response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1")) + response = handler(api, "t1", app_model=SimpleNamespace(id="app", workflow_id="wf-1")) assert response["items"][0]["id"] == "w1" assert response["page"] == 1 @@ -380,7 +368,7 @@ def test_draft_workflow_get_serializes_response_model(monkeypatch: pytest.Monkey ) api = workflow_module.DraftWorkflowApi() - handler = _unwrap(api.get) + handler = inspect.unwrap(api.get) response = handler(api, app_model=SimpleNamespace(id="app")) @@ -522,13 +510,13 @@ def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: ) api = workflow_module.DraftWorkflowApi() - handler = _unwrap(api.get) + handler = inspect.unwrap(api.get) with pytest.raises(DraftWorkflowNotExist): handler(api, app_model=SimpleNamespace(id="app")) -def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_advanced_chat_run_conversation_not_exists(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module.AppGenerateService, "generate", @@ -536,10 +524,9 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk workflow_module.services.errors.conversation.ConversationNotExistsError() ), ) - monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) api = workflow_module.AdvancedChatDraftWorkflowRunApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/app/advanced-chat/workflows/draft/run", @@ -547,10 +534,10 @@ def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.Monk json={"inputs": {}}, ): with pytest.raises(NotFound): - handler(api, app_model=SimpleNamespace(id="app")) + handler(api, "t1", app_model=SimpleNamespace(id="app")) -def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_online_users_filters_inaccessible_workflow(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: app_id_1 = "11111111-1111-1111-1111-111111111111" app_id_2 = "22222222-2222-2222-2222-222222222222" signed_avatar_url = "https://files.example.com/signed/avatar-1" @@ -602,7 +589,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory) api = workflow_module.WorkflowOnlineUsersApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/workflows/online-users", @@ -636,7 +623,7 @@ def test_workflow_online_users_filters_inaccessible_workflow(app, monkeypatch: p sign_avatar.assert_called_once_with("avatar-file-id") -def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_online_users_batches_redis_reads(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: app_ids = [f"wf-{index}" for index in range(workflow_module.WORKFLOW_ONLINE_USERS_REDIS_BATCH_SIZE + 1)] monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) monkeypatch.setattr( @@ -653,7 +640,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk monkeypatch.setattr(workflow_module.redis_client, "pipeline", redis_pipeline_factory) api = workflow_module.WorkflowOnlineUsersApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/workflows/online-users", @@ -668,7 +655,7 @@ def test_workflow_online_users_batches_redis_reads(app, monkeypatch: pytest.Monk assert second_pipeline.hgetall.call_count == 1 -def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: pytest.MonkeyPatch) -> None: +def test_workflow_online_users_rejects_excessive_workflow_ids(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "tenant-1")) accessible_app_ids = Mock(return_value=set()) monkeypatch.setattr( @@ -680,7 +667,7 @@ def test_workflow_online_users_rejects_excessive_workflow_ids(app, monkeypatch: excessive_ids = [f"wf-{index}" for index in range(workflow_module.MAX_WORKFLOW_ONLINE_USERS_REQUEST_IDS + 1)] api = workflow_module.WorkflowOnlineUsersApi() - handler = _unwrap(api.post) + handler = inspect.unwrap(api.post) with app.test_request_context( "/apps/workflows/online-users", diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py index cd0ceee2b1..8c9c9f9d56 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_comment_api.py @@ -99,7 +99,7 @@ class MutationResponseCase: expected_status: int | None = None -def _unwrap_response(result: object) -> tuple[dict[str, object], int | None]: +def unwrap_response(result: object) -> tuple[dict[str, object], int | None]: if isinstance(result, tuple): response, status = result assert isinstance(response, dict) @@ -194,7 +194,7 @@ def test_create_comment_allows_editor(app: Flask, monkeypatch: pytest.MonkeyPatc with _patch_payload(payload): result = workflow_comment_module.WorkflowCommentListApi().post(app_id="app-123") - response, status = _unwrap_response(result) + response, status = unwrap_response(result) assert response["id"] == "comment-1" assert status == 201 create_comment_mock.assert_called_once_with( @@ -224,7 +224,7 @@ def test_update_comment_omits_mentions_when_payload_does_not_include_them( with _patch_payload(payload): result = workflow_comment_module.WorkflowCommentDetailApi().put(app_id="app-123", comment_id="comment-1") - response, status = _unwrap_response(result) + response, status = unwrap_response(result) assert response == {"id": "comment-1", "updated_at": JAN_1_2024_NOON_TS} assert status is None update_comment_mock.assert_called_once_with( @@ -480,7 +480,7 @@ def test_mutation_endpoints_serialize_response_models( with _patch_payload(case.payload): result = getattr(case.resource_cls(), case.method_name)(**case.kwargs) - response, status = _unwrap_response(result) + response, status = unwrap_response(result) assert response == case.expected_response assert status == case.expected_status diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py index c76cb8d5d7..71034ebd40 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_run_api.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from typing import Any @@ -12,12 +13,6 @@ from controllers.console.app import workflow_run as workflow_run_module from models import Account -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _serialize_200_response(handler, payload: Any) -> Any: response_doc = getattr(handler, "__apidoc__", {}).get("responses", {}).get("200") if response_doc is None: @@ -100,7 +95,7 @@ def test_workflow_run_list_returns_frontend_history_contract(app: Flask, monkeyp monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) api = workflow_run_module.WorkflowRunListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/apps/app-1/workflow-runs?limit=10", method="GET"): payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) @@ -141,7 +136,7 @@ def test_advanced_chat_workflow_run_list_keeps_message_fields(app: Flask, monkey monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) api = workflow_run_module.AdvancedChatAppWorkflowRunListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/apps/app-1/advanced-chat/workflow-runs?limit=1", method="GET"): payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) @@ -180,7 +175,7 @@ def test_workflow_run_detail_returns_frontend_detail_contract(app: Flask, monkey monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) api = workflow_run_module.WorkflowRunDetailApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/apps/app-1/workflow-runs/run-1", method="GET"): payload = handler(api, app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1"), run_id="run-1") @@ -217,7 +212,7 @@ def test_workflow_run_node_executions_return_frontend_trace_contract( monkeypatch.setattr(workflow_run_module, "WorkflowRunService", WorkflowRunService) api = workflow_run_module.WorkflowRunNodeExecutionListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/apps/app-1/workflow-runs/run-1/node-executions", method="GET"): payload = handler( diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py index 51bbc33079..7f449bb376 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import PropertyMock, patch @@ -12,12 +13,6 @@ from controllers.console.auth.data_source_bearer_auth import ( ) -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _payload_patch(payload: dict): return patch.object( type(console_ns), @@ -29,7 +24,7 @@ def _payload_patch(payload: dict): def test_list_data_source_auth_uses_injected_tenant_id() -> None: api = ApiKeyAuthDataSource() - method = _unwrap(api.get) + method = unwrap(api.get) binding = SimpleNamespace( id="binding-1", category="api_key", @@ -52,7 +47,7 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None: def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: api = ApiKeyAuthDataSourceBinding() - method = _unwrap(api.post) + method = unwrap(api.post) payload = { "category": "api_key", "provider": "custom", @@ -73,7 +68,7 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None: api = ApiKeyAuthDataSourceBindingDelete() - method = _unwrap(api.delete) + method = unwrap(api.delete) with patch( "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" diff --git a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py index 906688d8c8..92656357d4 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_login_logout.py +++ b/api/tests/unit_tests/controllers/console/auth/test_login_logout.py @@ -41,13 +41,7 @@ def encode_code(code: str) -> str: return base64.b64encode(code.encode("utf-8")).decode() -def _unwrap(func): - bound_self = getattr(func, "__self__", None) - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - if bound_self is not None: - return func.__get__(bound_self, bound_self.__class__) - return func +from inspect import unwrap class TestLoginApi: @@ -510,7 +504,7 @@ class TestLogoutApi: # Act with app.test_request_context("/logout", method="POST"): logout_api = LogoutApi() - response = _unwrap(logout_api.post)(mock_account) + response = unwrap(logout_api.post)(logout_api, mock_account) # Assert mock_service_logout.assert_called_once_with(account=mock_account) @@ -536,7 +530,7 @@ class TestLogoutApi: # Act with app.test_request_context("/logout", method="POST"): logout_api = LogoutApi() - response = _unwrap(logout_api.post)(anonymous_user) + response = unwrap(logout_api.post)(logout_api, anonymous_user) # Assert assert response.json["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py index 1508d7b50e..9372ec8692 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inspect import unwrap from unittest.mock import patch from controllers.console.auth.oauth_server import OAuthServerUserAuthorizeApi @@ -8,12 +9,6 @@ from models.account import AccountStatus, TenantAccountRole from models.model import OAuthProviderApp -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _make_account() -> Account: account = Account( name="Test User", @@ -38,7 +33,7 @@ def _make_oauth_provider_app() -> OAuthProviderApp: def test_oauth_authorize_uses_injected_current_user() -> None: api = OAuthServerUserAuthorizeApi() - method = _unwrap(api.post) + method = unwrap(api.post) account = _make_account() oauth_provider_app = _make_oauth_provider_app() diff --git a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py index 22974ca416..ba69f4d6a7 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py +++ b/api/tests/unit_tests/controllers/console/auth/test_token_refresh.py @@ -49,7 +49,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair): + def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app: Flask, mock_token_pair): """ Test successful token refresh flow. @@ -170,7 +170,7 @@ class TestRefreshTokenApi: @patch("controllers.console.auth.login.extract_refresh_token", autospec=True) @patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True) - def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair): + def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app: Flask, mock_token_pair): """ Test that token refresh updates all three tokens. diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py index 5a3858aa03..5a66bc4e92 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -1,7 +1,6 @@ from __future__ import annotations -from collections.abc import Callable -from typing import Any, cast +from inspect import unwrap from unittest.mock import PropertyMock, patch import pytest @@ -20,12 +19,6 @@ from models.dataset import PipelineCustomizedTemplate from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity -def _unwrap(func: object) -> Callable[..., Any]: - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return cast(Callable[..., Any], func) - - def _template_item() -> dict[str, object]: return { "id": "template-1", @@ -60,7 +53,7 @@ def _payload() -> dict[str, object]: class TestPipelineTemplateListApi: def test_get_uses_query_defaults_and_serializes_nullable_fields(self, app: Flask) -> None: api = PipelineTemplateListApi() - method = _unwrap(api.get) + method = unwrap(api.get) service_calls: list[tuple[str, str]] = [] def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]: @@ -87,7 +80,7 @@ class TestPipelineTemplateListApi: def test_get_passes_explicit_query_to_service(self, app: Flask) -> None: api = PipelineTemplateListApi() - method = _unwrap(api.get) + method = unwrap(api.get) service_calls: list[tuple[str, str]] = [] def get_pipeline_templates(template_type: str, language: str) -> dict[str, object]: @@ -108,7 +101,7 @@ class TestPipelineTemplateListApi: class TestPipelineTemplateDetailApi: def test_get_serializes_template_detail(self, app: Flask) -> None: api = PipelineTemplateDetailApi() - method = _unwrap(api.get) + method = unwrap(api.get) service_calls: list[tuple[str, str]] = [] class Service: @@ -128,7 +121,7 @@ class TestPipelineTemplateDetailApi: def test_get_raises_not_found_without_custom_response_body(self, app: Flask) -> None: api = PipelineTemplateDetailApi() - method = _unwrap(api.get) + method = unwrap(api.get) class Service: def get_pipeline_template_detail(self, template_id: str, template_type: str) -> None: @@ -145,7 +138,7 @@ class TestPipelineTemplateDetailApi: class TestCustomizedPipelineTemplateApi: def test_patch_validates_payload_and_returns_empty_204(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() - method = _unwrap(api.patch) + method = unwrap(api.patch) payload = _payload() service_calls: list[tuple[str, PipelineTemplateInfoEntity]] = [] @@ -174,7 +167,7 @@ class TestCustomizedPipelineTemplateApi: def test_patch_defaults_missing_icon_info_before_service_call(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() - method = _unwrap(api.patch) + method = unwrap(api.patch) payload: dict[str, object] = { "name": "Updated template", "description": "Updated description", @@ -204,7 +197,7 @@ class TestCustomizedPipelineTemplateApi: def test_delete_returns_empty_204(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() - method = _unwrap(api.delete) + method = unwrap(api.delete) deleted_template_ids: list[str] = [] def delete_template(template_id: str) -> None: @@ -221,7 +214,7 @@ class TestCustomizedPipelineTemplateApi: def test_post_exports_yaml_from_orm_template(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() - method = _unwrap(api.post) + method = unwrap(api.post) template = PipelineCustomizedTemplate( tenant_id="00000000-0000-0000-0000-000000000001", name="Template", @@ -265,7 +258,7 @@ class TestCustomizedPipelineTemplateApi: def test_post_raises_when_template_is_missing(self, app: Flask) -> None: api = CustomizedPipelineTemplateApi() - method = _unwrap(api.post) + method = unwrap(api.post) class Session: def scalar(self, stmt: object) -> None: @@ -297,7 +290,7 @@ class TestCustomizedPipelineTemplateApi: class TestPublishCustomizedPipelineTemplateApi: def test_post_validates_payload_and_returns_empty_204(self, app: Flask) -> None: api = PublishCustomizedPipelineTemplateApi() - method = _unwrap(api.post) + method = unwrap(api.post) payload = _payload() service_calls: list[tuple[str, dict[str, object]]] = [] @@ -317,7 +310,7 @@ class TestPublishCustomizedPipelineTemplateApi: def test_post_allows_missing_icon_info_for_publish_service_fallback(self, app: Flask) -> None: api = PublishCustomizedPipelineTemplateApi() - method = _unwrap(api.post) + method = unwrap(api.post) payload: dict[str, object] = { "name": "Published template", "description": "Description", diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 322f1baa96..52e36fd521 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import PropertyMock, patch @@ -9,12 +10,6 @@ import pytest from controllers.console.datasets.rag_pipeline import rag_pipeline_workflow as module -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _make_workflow(**overrides): workflow = SimpleNamespace( id="workflow-1", @@ -45,7 +40,7 @@ def test_draft_rag_pipeline_workflow_get_serializes_response_model(monkeypatch: ) api = module.DraftRagPipelineApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) response = handler(api, pipeline=SimpleNamespace(id="pipeline-1")) @@ -63,7 +58,7 @@ def test_published_rag_pipeline_workflows_serialize_items_before_session_closes( app, monkeypatch: pytest.MonkeyPatch ) -> None: api = module.PublishedAllRagPipelineApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) session_state = {"open": False} class _SessionContext: @@ -133,7 +128,7 @@ def test_rag_pipeline_workflow_patch_serializes_response_model(app, monkeypatch: payload: dict[str, object] = {"marked_name": "Updated release"} api = module.RagPipelineByIdApi() - handler = _unwrap(api.patch) + handler = unwrap(api.patch) with ( app.test_request_context("/rag/pipelines/pipeline-1/workflows/workflow-1", method="PATCH", json=payload), diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py index 80a688ab0e..956b034673 100644 --- a/api/tests/unit_tests/controllers/console/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -2,6 +2,7 @@ from __future__ import annotations import json from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -23,12 +24,6 @@ from models.human_input import RecipientType from models.model import AppMode -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def test_jsonify_form_definition() -> None: expiration = datetime(2024, 1, 1, tzinfo=UTC) definition = SimpleNamespace(model_dump=lambda: {"fields": []}) @@ -64,7 +59,7 @@ def test_get_form_definition_success(app: Flask, monkeypatch: pytest.MonkeyPatch monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/console/api/form/human_input/token", method="GET"): response = handler(api, "tenant-1", form_token="token") @@ -85,7 +80,7 @@ def test_get_form_definition_not_found(app: Flask, monkeypatch: pytest.MonkeyPat monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/console/api/form/human_input/token", method="GET"): with pytest.raises(NotFoundError): @@ -106,7 +101,7 @@ def test_post_form_invalid_recipient_type(app: Flask, monkeypatch: pytest.Monkey monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/console/api/form/human_input/token", @@ -137,7 +132,7 @@ def test_post_form_rejects_webapp_recipient_type(app: Flask, monkeypatch: pytest monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/console/api/form/human_input/token", @@ -172,7 +167,7 @@ def test_post_form_success(app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleHumanInputFormApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context( "/console/api/form/human_input/token", @@ -244,7 +239,7 @@ def test_workflow_events_not_found(app: Flask, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/console/api/workflow/run/events", method="GET"): with pytest.raises(NotFoundError): @@ -271,7 +266,7 @@ def test_workflow_events_requires_account(app: Flask, monkeypatch: pytest.Monkey monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/console/api/workflow/run/events", method="GET"): with pytest.raises(NotFoundError): @@ -298,7 +293,7 @@ def test_workflow_events_requires_creator(app: Flask, monkeypatch: pytest.Monkey monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/console/api/workflow/run/events", method="GET"): with pytest.raises(NotFoundError): @@ -342,7 +337,7 @@ def test_workflow_events_finished(app: Flask, monkeypatch: pytest.MonkeyPatch) - monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) api = ConsoleWorkflowEventsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) with app.test_request_context("/console/api/workflow/run/events", method="GET"): response = handler(api, "t1", SimpleNamespace(id="user-1"), workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py index 206efec4c6..e7127aef23 100644 --- a/api/tests/unit_tests/controllers/console/test_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -2,6 +2,7 @@ from __future__ import annotations import urllib.parse from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import MagicMock @@ -16,12 +17,6 @@ from services.errors.file import FileTooLargeError as ServiceFileTooLargeError from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _make_account(account_id: str = "u1") -> Account: account = Account( name="Test User", @@ -89,7 +84,7 @@ def _mock_upload_dependencies( def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.GetRemoteFileInfo() - handler = _unwrap(api.get) + handler = unwrap(api.get) decoded_url = "https://example.com/test.txt" encoded_url = urllib.parse.quote(decoded_url, safe="") @@ -110,7 +105,7 @@ def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.GetRemoteFileInfo() - handler = _unwrap(api.get) + handler = unwrap(api.get) target_url = "http://example.com/api/aiagent/httpview/txt" query = "fileNameKey=cankao1_ce4305bc-be20-4c5d-8732-de1741d28e27" @@ -131,7 +126,7 @@ def test_get_remote_file_info_preserves_unencoded_target_query(app, monkeypatch: def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.GetRemoteFileInfo() - handler = _unwrap(api.get) + handler = unwrap(api.get) decoded_url = "https://example.com/test.txt" encoded_url = urllib.parse.quote(decoded_url, safe="") @@ -154,7 +149,7 @@ def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, mo def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/report.txt" get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content") @@ -196,7 +191,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( app, monkeypatch: pytest.MonkeyPatch ) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/photo.jpg" head_resp = _FakeResponse(status_code=200, method="HEAD", content=b"head-content") @@ -227,7 +222,7 @@ def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/fail.txt" make_request = MagicMock( @@ -245,7 +240,7 @@ def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypat def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/fail.txt" request = httpx.Request("HEAD", url) @@ -259,7 +254,7 @@ def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pyte def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/large.bin" make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")) @@ -274,7 +269,7 @@ def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.Monk def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/large.bin" make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")) @@ -289,7 +284,7 @@ def test_remote_file_upload_translates_service_file_too_large_error(app, monkeyp def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None: api = remote_files_module.RemoteFileUpload() - handler = _unwrap(api.post) + handler = unwrap(api.post) url = "https://example.com/file.exe" make_request = MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")) diff --git a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py index 52fd0f89d5..da4289bdde 100644 --- a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py @@ -10,6 +10,7 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from flask import Flask from werkzeug.exceptions import NotFound from controllers.openapi.auth.data import AuthData @@ -30,7 +31,7 @@ def _make_auth_data(app_model, caller, caller_kind): class TestOpenApiHumanInputFormGet: - def test_get_success(self, app, bypass_pipeline, monkeypatch): + def test_get_success(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi definition = SimpleNamespace( @@ -74,7 +75,7 @@ class TestOpenApiHumanInputFormGet: assert payload["user_actions"] == [{"id": "submit", "title": "Submit"}] service_mock.ensure_form_active.assert_called_once_with(form) - def test_get_form_not_found(self, app, bypass_pipeline, monkeypatch): + def test_get_form_not_found(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi service_mock = Mock() @@ -96,7 +97,7 @@ class TestOpenApiHumanInputFormGet: auth_data=_make_auth_data(app_model, caller, "account"), ) - def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch): + def test_get_form_wrong_app(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi form = SimpleNamespace( @@ -121,7 +122,7 @@ class TestOpenApiHumanInputFormGet: auth_data=_make_auth_data(app_model, caller, "account"), ) - def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch): + def test_get_form_wrong_surface(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi form = SimpleNamespace( @@ -159,7 +160,7 @@ class TestOpenApiHumanInputFormPost: expiration_time=datetime(2099, 1, 1, tzinfo=UTC), ) - def test_post_account_caller_uses_user_id(self, app, bypass_pipeline, monkeypatch): + def test_post_account_caller_uses_user_id(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi form = self._make_form() @@ -196,7 +197,7 @@ class TestOpenApiHumanInputFormPost: ) assert result == ({}, 200) - def test_post_end_user_caller_uses_end_user_id(self, app, bypass_pipeline, monkeypatch): + def test_post_end_user_caller_uses_end_user_id(self, app: Flask, bypass_pipeline, monkeypatch: pytest.MonkeyPatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi form = self._make_form() diff --git a/api/tests/unit_tests/controllers/service_api/app/test_annotation.py b/api/tests/unit_tests/controllers/service_api/app/test_annotation.py index 6d586d31a9..b4dd5e957c 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_annotation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_annotation.py @@ -13,6 +13,7 @@ Note: API endpoint tests for annotation controllers are complex due to: """ import uuid +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -34,13 +35,6 @@ from extensions.ext_redis import redis_client from models.model import App from services.annotation_service import AppAnnotationService - -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - # --------------------------------------------------------------------------- # Pydantic Model Tests # --------------------------------------------------------------------------- @@ -193,7 +187,7 @@ class TestAnnotationReplyActionApi: monkeypatch.setattr(AppAnnotationService, "enable_app_annotation", enable_mock) api = AnnotationReplyActionApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="app") with app.test_request_context( @@ -211,7 +205,7 @@ class TestAnnotationReplyActionApi: monkeypatch.setattr(AppAnnotationService, "disable_app_annotation", disable_mock) api = AnnotationReplyActionApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="app") with app.test_request_context( @@ -230,7 +224,7 @@ class TestAnnotationReplyActionStatusApi: monkeypatch.setattr(redis_client, "get", lambda *_args, **_kwargs: None) api = AnnotationReplyActionStatusApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app") with pytest.raises(ValueError): @@ -245,7 +239,7 @@ class TestAnnotationReplyActionStatusApi: monkeypatch.setattr(redis_client, "get", _get) api = AnnotationReplyActionStatusApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app") response, status = handler(api, app_model=app_model, job_id="j1", action="enable") @@ -262,7 +256,7 @@ class TestAnnotationListApi: monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock) api = AnnotationListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app") with app.test_request_context("/apps/annotations", method="GET"): @@ -278,7 +272,7 @@ class TestAnnotationListApi: monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock) api = AnnotationListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app") with app.test_request_context("/apps/annotations?page=2&limit=5&keyword=refund", method="GET"): @@ -297,7 +291,7 @@ class TestAnnotationListApi: monkeypatch.setattr(AppAnnotationService, "get_annotation_list_by_app_id", get_mock) api = AnnotationListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app") with app.test_request_context(f"/apps/annotations?{query_string}", method="GET"): @@ -315,7 +309,7 @@ class TestAnnotationListApi: ) api = AnnotationListApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="app") with app.test_request_context("/apps/annotations", method="POST", json={"question": "q", "answer": "a"}): @@ -337,8 +331,8 @@ class TestAnnotationUpdateDeleteApi: monkeypatch.setattr(AppAnnotationService, "delete_app_annotation", delete_mock) api = AnnotationUpdateDeleteApi() - put_handler = _unwrap(api.put) - delete_handler = _unwrap(api.delete) + put_handler = unwrap(api.put) + delete_handler = unwrap(api.delete) app_model = SimpleNamespace(id="app") with app.test_request_context("/apps/annotations/1", method="PUT", json={"question": "q", "answer": "a"}): diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 4741481ef6..1cfe152c86 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -9,6 +9,7 @@ Tests coverage for: import io import uuid +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock, patch @@ -41,12 +42,6 @@ from services.errors.audio import ( ) -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - def _file_data(): return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav") @@ -194,7 +189,7 @@ class TestAudioApi: def test_success(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) api = AudioApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") end_user = SimpleNamespace(id="u1") @@ -220,7 +215,7 @@ class TestAudioApi: def test_error_mapping(self, app: Flask, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) api = AudioApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") end_user = SimpleNamespace(id="u1") @@ -233,7 +228,7 @@ class TestAudioApi: AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom")) ) api = AudioApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") end_user = SimpleNamespace(id="u1") @@ -247,7 +242,7 @@ class TestTextApi: monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) api = TextApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") end_user = SimpleNamespace(external_user_id="ext") @@ -266,7 +261,7 @@ class TestTextApi: ) api = TextApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="a1") end_user = SimpleNamespace(external_user_id="ext") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 745df9c798..46ce1a85ca 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -12,6 +12,7 @@ Focus on: """ import uuid +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock, patch @@ -44,12 +45,6 @@ from services.errors.conversation import ConversationNotExistsError from services.errors.llm import InvokeRateLimitError -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - class TestCompletionRequestPayload: """Test suite for CompletionRequestPayload Pydantic model.""" @@ -426,7 +421,7 @@ class TestChatRequestPayloadController: class TestCompletionApiController: def test_wrong_mode(self, app: Flask) -> None: api = CompletionApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -444,7 +439,7 @@ class TestCompletionApiController: end_user = SimpleNamespace() api = CompletionApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) with app.test_request_context("/completion-messages", method="POST", json={"inputs": {}}): with pytest.raises(NotFound): @@ -454,7 +449,7 @@ class TestCompletionApiController: class TestCompletionStopApiController: def test_wrong_mode(self, app: Flask) -> None: api = CompletionStopApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace(id="u1") @@ -467,7 +462,7 @@ class TestCompletionStopApiController: monkeypatch.setattr(AppTaskService, "stop_task", stop_mock) api = CompletionStopApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.COMPLETION) end_user = SimpleNamespace(id="u1") @@ -481,7 +476,7 @@ class TestCompletionStopApiController: class TestChatApiController: def test_wrong_mode(self, app: Flask) -> None: api = ChatApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.COMPLETION.value) end_user = SimpleNamespace() @@ -497,7 +492,7 @@ class TestChatApiController: ) api = ChatApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -513,7 +508,7 @@ class TestChatApiController: ) api = ChatApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -525,7 +520,7 @@ class TestChatApiController: class TestChatStopApiController: def test_wrong_mode(self, app: Flask) -> None: api = ChatStopApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.COMPLETION) end_user = SimpleNamespace(id="u1") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py index abb476a750..97873c631a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_conversation.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_conversation.py @@ -16,6 +16,7 @@ Focus on: import sys import uuid from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock, patch @@ -51,12 +52,6 @@ from services.errors.conversation import ( ) -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - class TestConversationListQuery: """Test suite for ConversationListQuery Pydantic model.""" @@ -380,7 +375,7 @@ class TestConversationAppModeValidation: app raises NotChatAppError. """ app = Mock(spec=App) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW app_mode = AppMode.value_of(app.mode) assert app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT} @@ -498,7 +493,7 @@ class TestConversationPayloadsController: class TestConversationApiController: def test_list_not_chat(self, app: Flask) -> None: api = ConversationApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.COMPLETION) end_user = SimpleNamespace() @@ -531,7 +526,7 @@ class TestConversationApiController: monkeypatch.setattr(conversation_module, "sessionmaker", _SessionMakerStub) api = ConversationApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -546,7 +541,7 @@ class TestConversationApiController: class TestConversationDetailApiController: def test_delete_not_chat(self, app: Flask) -> None: api = ConversationDetailApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) app_model = SimpleNamespace(mode=AppMode.COMPLETION) end_user = SimpleNamespace() @@ -562,7 +557,7 @@ class TestConversationDetailApiController: ) api = ConversationDetailApi() - handler = _unwrap(api.delete) + handler = unwrap(api.delete) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -580,7 +575,7 @@ class TestConversationRenameApiController: ) api = ConversationRenameApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -596,7 +591,7 @@ class TestConversationRenameApiController: class TestConversationVariablesApiController: def test_not_chat(self, app: Flask) -> None: api = ConversationVariablesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.COMPLETION) end_user = SimpleNamespace() @@ -612,7 +607,7 @@ class TestConversationVariablesApiController: ) api = ConversationVariablesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -645,7 +640,7 @@ class TestConversationVariablesApiController: ) api = ConversationVariablesApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -671,7 +666,7 @@ class TestConversationVariableDetailApiController: ) api = ConversationVariableDetailApi() - handler = _unwrap(api.put) + handler = unwrap(api.put) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -697,7 +692,7 @@ class TestConversationVariableDetailApiController: ) api = ConversationVariableDetailApi() - handler = _unwrap(api.put) + handler = unwrap(api.put) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() @@ -731,7 +726,7 @@ class TestConversationVariableDetailApiController: ) api = ConversationVariableDetailApi() - handler = _unwrap(api.put) + handler = unwrap(api.put) app_model = SimpleNamespace(mode=AppMode.CHAT) end_user = SimpleNamespace() diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file.py b/api/tests/unit_tests/controllers/service_api/app/test_file.py index 88ebe955a8..e44f6cd06c 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file.py @@ -203,7 +203,7 @@ class TestFileUploadResponse: # unwrapped method directly to bypass the decorator. # ============================================================================= -from tests.unit_tests.controllers.service_api.conftest import _unwrap +from inspect import unwrap @pytest.fixture @@ -274,7 +274,7 @@ class TestFileApiPost: data=data, ): api = FileApi() - response, status = _unwrap(api.post)( + response, status = unwrap(api.post)( api, app_model=mock_app_model, end_user=mock_end_user, @@ -295,7 +295,7 @@ class TestFileApiPost: ): api = FileApi() with pytest.raises(NoFileUploadedError): - _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) + unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) def test_upload_too_many_files(self, app: Flask, mock_app_model, mock_end_user): """Test TooManyFilesError when multiple files uploaded.""" @@ -316,7 +316,7 @@ class TestFileApiPost: ): api = FileApi() with pytest.raises(TooManyFilesError): - _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) + unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) def test_upload_no_mimetype(self, app: Flask, mock_app_model, mock_end_user): """Test UnsupportedFileTypeError when file has no mimetype.""" @@ -334,7 +334,7 @@ class TestFileApiPost: ): api = FileApi() with pytest.raises(UnsupportedFileTypeError): - _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) + unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) @patch("controllers.service_api.app.file.FileService") @patch("controllers.service_api.app.file.db") @@ -366,7 +366,7 @@ class TestFileApiPost: ): api = FileApi() with pytest.raises(FileTooLargeError): - _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) + unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) @patch("controllers.service_api.app.file.FileService") @patch("controllers.service_api.app.file.db") @@ -396,4 +396,4 @@ class TestFileApiPost: ): api = FileApi() with pytest.raises(UnsupportedFileTypeError): - _unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) + unwrap(api.post)(api, app_model=mock_app_model, end_user=mock_end_user) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py index de52c62fdd..8686f49a4a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_hitl_service_api.py @@ -7,6 +7,7 @@ import sys from collections.abc import Sequence from dataclasses import dataclass from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from typing import override from unittest.mock import ANY, MagicMock, Mock @@ -44,7 +45,6 @@ from repositories.api_workflow_node_execution_repository import WorkflowNodeExec from repositories.entities.workflow_pause import WorkflowPauseEntity from services.app_generate_service import AppGenerateService from services.workflow_event_snapshot_service import _build_snapshot_events -from tests.unit_tests.controllers.service_api.conftest import _unwrap class _DummyRateLimit: @@ -275,8 +275,8 @@ class TestHitlServiceApi: monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context("/workflow/run-1/events?user=u1&continue_on_pause=true", method="GET"): @@ -310,8 +310,8 @@ class TestHitlServiceApi: monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py index ce000ab5a2..0f47f0d630 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_human_input_form.py @@ -5,6 +5,7 @@ from __future__ import annotations import json import sys from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -15,7 +16,6 @@ from werkzeug.exceptions import NotFound from controllers.common.human_input import HumanInputFormSubmitPayload from controllers.service_api.app.human_input_form import WorkflowHumanInputFormApi from models.human_input import RecipientType -from tests.unit_tests.controllers.service_api.conftest import _unwrap class TestWorkflowHumanInputFormApi: @@ -45,7 +45,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") with app.test_request_context("/form/human_input/token-1", method="GET"): @@ -98,7 +98,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") with app.test_request_context("/form/human_input/token-1", method="GET"): @@ -121,7 +121,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") with app.test_request_context("/form/human_input/token-1", method="GET"): @@ -153,7 +153,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") with app.test_request_context("/form/human_input/token-1", method="GET"): @@ -175,7 +175,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") end_user = SimpleNamespace(id="end-user-1") @@ -209,7 +209,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") end_user = SimpleNamespace(id="end-user-1") inputs = { @@ -285,7 +285,7 @@ class TestWorkflowHumanInputFormApi: monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) api = WorkflowHumanInputFormApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") end_user = SimpleNamespace(id="end-user-1") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index 1fda5ce9cf..d8d5c61bcb 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -15,6 +15,7 @@ Focus on: """ import uuid +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock, patch @@ -43,12 +44,6 @@ from services.errors.message import ( from services.message_service import MessageService -def _unwrap(func): - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - return func - - class TestMessageListQuery: """Test suite for MessageListQuery Pydantic model.""" @@ -383,7 +378,7 @@ class TestMessageService: class TestMessageListApi: def test_not_chat_app(self, app: Flask) -> None: api = MessageListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.COMPLETION.value) end_user = SimpleNamespace() @@ -399,7 +394,7 @@ class TestMessageListApi: ) api = MessageListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -418,7 +413,7 @@ class TestMessageListApi: ) api = MessageListApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -439,7 +434,7 @@ class TestMessageFeedbackApi: ) api = MessageFeedbackApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace() end_user = SimpleNamespace() @@ -457,7 +452,7 @@ class TestAppGetFeedbacksApi: monkeypatch.setattr(MessageService, "get_all_messages_feedbacks", lambda *_args, **_kwargs: ["f1"]) api = AppGetFeedbacksApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace() with app.test_request_context("/app/feedbacks?page=1&limit=20", method="GET"): @@ -469,7 +464,7 @@ class TestAppGetFeedbacksApi: class TestMessageSuggestedApi: def test_not_chat(self, app: Flask) -> None: api = MessageSuggestedApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.COMPLETION.value) end_user = SimpleNamespace() @@ -485,7 +480,7 @@ class TestMessageSuggestedApi: ) api = MessageSuggestedApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -501,7 +496,7 @@ class TestMessageSuggestedApi: ) api = MessageSuggestedApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -517,7 +512,7 @@ class TestMessageSuggestedApi: ) api = MessageSuggestedApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -533,7 +528,7 @@ class TestMessageSuggestedApi: ) api = MessageSuggestedApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 7115ea1e12..4f88ae69c2 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -16,6 +16,7 @@ Focus on: import sys import uuid from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock, patch @@ -369,7 +370,7 @@ class TestWorkflowRunRepository: class TestWorkflowRunDetailApi: def test_not_workflow_app(self, app: Flask) -> None: api = WorkflowRunDetailApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) with app.test_request_context("/workflows/run/1", method="GET"): @@ -388,8 +389,8 @@ class TestWorkflowRunDetailApi: ) api = WorkflowRunDetailApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value, tenant_id="t1", id="a1") + handler = unwrap(api.get) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW, tenant_id="t1", id="a1") result = handler(api, app_model=app_model, workflow_run_id="run") assert result["id"] == "run" @@ -400,7 +401,7 @@ class TestWorkflowRunDetailApi: class TestWorkflowRunApi: def test_not_workflow_app(self, app: Flask) -> None: api = WorkflowRunApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -416,8 +417,8 @@ class TestWorkflowRunApi: ) api = WorkflowRunApi() - handler = _unwrap(api.post) - app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value) + handler = unwrap(api.post) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW) end_user = SimpleNamespace() with app.test_request_context("/workflows/run", method="POST", json={"inputs": {}}): @@ -434,8 +435,8 @@ class TestWorkflowRunByIdApi: ) api = WorkflowRunByIdApi() - handler = _unwrap(api.post) - app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value) + handler = unwrap(api.post) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW) end_user = SimpleNamespace() with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}): @@ -450,8 +451,8 @@ class TestWorkflowRunByIdApi: ) api = WorkflowRunByIdApi() - handler = _unwrap(api.post) - app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value) + handler = unwrap(api.post) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW) end_user = SimpleNamespace() with app.test_request_context("/workflows/1/run", method="POST", json={"inputs": {}}): @@ -462,7 +463,7 @@ class TestWorkflowRunByIdApi: class TestWorkflowTaskStopApi: def test_wrong_mode(self, app: Flask) -> None: api = WorkflowTaskStopApi() - handler = _unwrap(api.post) + handler = unwrap(api.post) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace() @@ -477,8 +478,8 @@ class TestWorkflowTaskStopApi: monkeypatch.setattr(GraphEngineManager, "send_stop_command", send_mock) api = WorkflowTaskStopApi() - handler = _unwrap(api.post) - app_model = SimpleNamespace(mode=AppMode.WORKFLOW.value) + handler = unwrap(api.post) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="u1") with app.test_request_context("/workflows/tasks/1/stop", method="POST"): @@ -515,7 +516,7 @@ class TestWorkflowAppLogApi: ) api = WorkflowAppLogApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(id="a1") with app.test_request_context("/workflows/logs", method="GET"): @@ -533,15 +534,13 @@ class TestWorkflowAppLogApi: # directly to bypass the decorator. # ============================================================================= -from tests.unit_tests.controllers.service_api.conftest import _unwrap - @pytest.fixture def mock_workflow_app(): app = Mock(spec=App) app.id = str(uuid.uuid4()) app.tenant_id = str(uuid.uuid4()) - app.mode = AppMode.WORKFLOW.value + app.mode = AppMode.WORKFLOW return app @@ -574,7 +573,7 @@ class TestWorkflowRunDetailApiGet: method="GET", ): api = WorkflowRunDetailApi() - result = _unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) + result = unwrap(api.get)(api, app_model=mock_workflow_app, workflow_run_id=mock_run.id) assert result["id"] == mock_run.id assert result["status"] == "succeeded" @@ -590,7 +589,7 @@ class TestWorkflowRunDetailApiGet: with app.test_request_context("/workflows/run/run-1", method="GET"): api = WorkflowRunDetailApi() with pytest.raises(NotWorkflowAppError): - _unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1") + unwrap(api.get)(api, app_model=mock_app, workflow_run_id="run-1") class TestWorkflowTaskStopApiPost: @@ -613,7 +612,7 @@ class TestWorkflowTaskStopApiPost: with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): api = WorkflowTaskStopApi() - result = _unwrap(api.post)( + result = unwrap(api.post)( api, app_model=mock_workflow_app, end_user=Mock(), @@ -635,7 +634,7 @@ class TestWorkflowTaskStopApiPost: with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): api = WorkflowTaskStopApi() with pytest.raises(NotWorkflowAppError): - _unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1") + unwrap(api.post)(api, app_model=mock_app, end_user=Mock(), task_id="task-1") class TestWorkflowAppLogApiGet: @@ -681,6 +680,6 @@ class TestWorkflowAppLogApiGet: ): with patch("controllers.service_api.app.workflow.sessionmaker", return_value=mock_session_factory): api = WorkflowAppLogApi() - result = _unwrap(api.get)(api, app_model=mock_workflow_app) + result = unwrap(api.get)(api, app_model=mock_workflow_app) assert result == {"page": 1, "limit": 20, "total": 0, "has_more": False, "data": []} diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py index a1aca06570..94b7c8bca1 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_events.py @@ -5,6 +5,7 @@ from __future__ import annotations import json import sys from datetime import UTC, datetime +from inspect import unwrap from types import SimpleNamespace from unittest.mock import Mock @@ -16,7 +17,6 @@ from controllers.service_api.app.error import NotWorkflowAppError from controllers.service_api.app.workflow_events import WorkflowEventsApi from models.enums import CreatorUserRole from models.model import AppMode -from tests.unit_tests.controllers.service_api.conftest import _unwrap def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run): @@ -34,7 +34,7 @@ def _mock_repo_for_run(monkeypatch: pytest.MonkeyPatch, workflow_run): class TestWorkflowEventsApi: def test_wrong_app_mode(self, app: Flask) -> None: api = WorkflowEventsApi() - handler = _unwrap(api.get) + handler = unwrap(api.get) app_model = SimpleNamespace(mode=AppMode.CHAT.value) end_user = SimpleNamespace(id="end-user-1") @@ -45,8 +45,8 @@ class TestWorkflowEventsApi: def test_workflow_run_not_found(self, app: Flask, monkeypatch: pytest.MonkeyPatch) -> None: _mock_repo_for_run(monkeypatch, workflow_run=None) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): @@ -63,8 +63,8 @@ class TestWorkflowEventsApi: ) _mock_repo_for_run(monkeypatch, workflow_run=workflow_run) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): @@ -90,8 +90,8 @@ class TestWorkflowEventsApi: ) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): @@ -121,8 +121,8 @@ class TestWorkflowEventsApi: monkeypatch.setattr(workflow_events_module, "WorkflowAppGenerator", lambda: workflow_generator) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context("/workflow/run-1/events?user=u1", method="GET"): @@ -154,8 +154,8 @@ class TestWorkflowEventsApi: monkeypatch.setattr(workflow_events_module, "build_workflow_event_stream", snapshot_builder) api = WorkflowEventsApi() - handler = _unwrap(api.get) - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + handler = unwrap(api.get) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) end_user = SimpleNamespace(id="end-user-1") with app.test_request_context("/workflow/run-1/events?user=u1&include_state_snapshot=true", method="GET"): diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 8c89812cb4..fff64efd4c 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -204,11 +204,3 @@ def mock_child_chunk(): child_chunk.tenant_id = str(uuid.uuid4()) child_chunk.content = "Test child chunk content" return child_chunk - - -def _unwrap(method): - """Walk ``__wrapped__`` chain to get the original function.""" - fn = method - while hasattr(fn, "__wrapped__"): - fn = fn.__wrapped__ - return fn diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py index 5db87df0a2..7a9978e742 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_metadata.py @@ -16,6 +16,7 @@ Decorator strategy: """ import uuid +from inspect import unwrap from unittest.mock import Mock, patch import pytest @@ -29,7 +30,6 @@ from controllers.service_api.dataset.metadata import ( DatasetMetadataServiceApi, DocumentMetadataEditServiceApi, ) -from tests.unit_tests.controllers.service_api.conftest import _unwrap @pytest.fixture @@ -65,7 +65,7 @@ class TestDatasetMetadataCreatePost: @staticmethod def _call_post(api, **kwargs): - return _unwrap(api.post)(api, **kwargs) + return unwrap(api.post)(api, **kwargs) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -195,7 +195,7 @@ class TestDatasetMetadataServiceApiPatch: @staticmethod def _call_patch(api, **kwargs): - return _unwrap(api.patch)(api, **kwargs) + return unwrap(api.patch)(api, **kwargs) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -267,7 +267,7 @@ class TestDatasetMetadataServiceApiDelete: @staticmethod def _call_delete(api, **kwargs): - return _unwrap(api.delete)(api, **kwargs) + return unwrap(api.delete)(api, **kwargs) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -376,7 +376,7 @@ class TestDatasetMetadataBuiltInFieldAction: @staticmethod def _call_post(api, **kwargs): - return _unwrap(api.post)(api, **kwargs) + return unwrap(api.post)(api, **kwargs) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") @@ -479,7 +479,7 @@ class TestDocumentMetadataEditPost: @staticmethod def _call_post(api, **kwargs): - return _unwrap(api.post)(api, **kwargs) + return unwrap(api.post)(api, **kwargs) @patch("controllers.service_api.dataset.metadata.MetadataService") @patch("controllers.service_api.dataset.metadata.DatasetService") diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py index f8dd6bf609..e62f4584d6 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py @@ -7,7 +7,7 @@ from models.model import AppMode class TestWorkflowAppConfigManager: def test_get_app_config(self): - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) workflow = SimpleNamespace(id="wf-1", features_dict={}) with ( diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index cb494ab8db..1199fc773f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -383,7 +383,7 @@ class TestWorkflowService: assert result == mock_workflow - def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service): + def test_get_draft_workflow_with_workflow_id_reuses_provided_session(self, workflow_service: WorkflowService): """Test get_draft_workflow passes an injected session to published workflow lookup.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() workflow_id = "workflow-123" @@ -458,7 +458,7 @@ class TestWorkflowService: assert result == mock_workflow - def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service): + def test_get_published_workflow_returns_none_when_no_workflow_id(self, workflow_service: WorkflowService): """Test get_published_workflow returns None when app has no workflow_id.""" app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None) @@ -658,21 +658,21 @@ class TestWorkflowService: # ==================== Workflow Validation Tests ==================== # These tests verify graph structure and feature configuration validation - def test_validate_graph_structure_empty_graph(self, workflow_service): + def test_validate_graph_structure_empty_graph(self, workflow_service: WorkflowService): """Test validate_graph_structure accepts empty graph.""" graph = {"nodes": []} # Should not raise any exception workflow_service.validate_graph_structure(graph) - def test_validate_graph_structure_valid_graph(self, workflow_service): + def test_validate_graph_structure_valid_graph(self, workflow_service: WorkflowService): """Test validate_graph_structure accepts valid graph.""" graph = TestWorkflowAssociatedDataFactory.create_valid_workflow_graph() # Should not raise any exception workflow_service.validate_graph_structure(graph) - def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service): + def test_validate_graph_structure_start_and_trigger_coexist_raises_error(self, workflow_service: WorkflowService): """ Test validate_graph_structure raises error when start and trigger nodes coexist. @@ -707,7 +707,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): workflow_service.validate_graph_structure(graph) - def test_validate_features_structure_workflow_mode(self, workflow_service): + def test_validate_features_structure_workflow_mode(self, workflow_service: WorkflowService): """ Test validate_features_structure for workflow mode. @@ -723,7 +723,7 @@ class TestWorkflowService: tenant_id=app.tenant_id, config=features, only_structure_validate=True ) - def test_validate_features_structure_advanced_chat_mode(self, workflow_service): + def test_validate_features_structure_advanced_chat_mode(self, workflow_service: WorkflowService): """Test validate_features_structure for advanced chat mode.""" app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.ADVANCED_CHAT) features = {"opening_statement": "Hello"} @@ -734,7 +734,7 @@ class TestWorkflowService: tenant_id=app.tenant_id, config=features, only_structure_validate=True ) - def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service): + def test_validate_features_structure_invalid_mode_raises_error(self, workflow_service: WorkflowService): """Test validate_features_structure raises error for invalid mode.""" app = TestWorkflowAssociatedDataFactory.create_app_mock(mode=AppMode.COMPLETION) features = {} @@ -767,7 +767,7 @@ class TestWorkflowService: assert workflow.updated_at == "now" mock_db_session.session.commit.assert_called_once() - def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service): + def test_update_draft_workflow_environment_variables_raises_when_missing(self, workflow_service: WorkflowService): """Test update_draft_workflow_environment_variables raises when draft missing.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() account = TestWorkflowAssociatedDataFactory.create_account_mock() @@ -802,7 +802,7 @@ class TestWorkflowService: assert workflow.updated_at == "now" mock_db_session.session.commit.assert_called_once() - def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service): + def test_update_draft_workflow_conversation_variables_raises_when_missing(self, workflow_service: WorkflowService): """Test update_draft_workflow_conversation_variables raises when draft missing.""" app = TestWorkflowAssociatedDataFactory.create_app_mock() account = TestWorkflowAssociatedDataFactory.create_account_mock() @@ -917,7 +917,7 @@ class TestWorkflowService: # ==================== Version Management Tests ==================== # These tests verify listing and managing published workflow versions - def test_get_all_published_workflow_with_pagination(self, workflow_service): + def test_get_all_published_workflow_with_pagination(self, workflow_service: WorkflowService): """ Test get_all_published_workflow returns paginated results. @@ -950,7 +950,7 @@ class TestWorkflowService: assert len(workflows) == 5 assert has_more is False - def test_get_all_published_workflow_has_more(self, workflow_service): + def test_get_all_published_workflow_has_more(self, workflow_service: WorkflowService): """ Test get_all_published_workflow indicates has_more when results exceed limit. @@ -983,7 +983,7 @@ class TestWorkflowService: assert len(workflows) == 10 assert has_more is True - def test_get_all_published_workflow_no_workflow_id(self, workflow_service): + def test_get_all_published_workflow_no_workflow_id(self, workflow_service: WorkflowService): """Test get_all_published_workflow returns empty when app has no workflow_id.""" app = TestWorkflowAssociatedDataFactory.create_app_mock(workflow_id=None) mock_session = MagicMock() @@ -998,7 +998,7 @@ class TestWorkflowService: # ==================== Update Workflow Tests ==================== # These tests verify updating workflow metadata (name, comments, etc.) - def test_update_workflow_success(self, workflow_service): + def test_update_workflow_success(self, workflow_service: WorkflowService): """ Test update_workflow updates workflow attributes. @@ -1032,7 +1032,7 @@ class TestWorkflowService: assert mock_workflow.marked_comment == "Updated Comment" assert mock_workflow.updated_by == account_id - def test_update_workflow_not_found(self, workflow_service): + def test_update_workflow_not_found(self, workflow_service: WorkflowService): """Test update_workflow returns None when workflow not found.""" mock_session = MagicMock() mock_session.scalar.return_value = None @@ -1055,7 +1055,7 @@ class TestWorkflowService: # ==================== Delete Workflow Tests ==================== # These tests verify workflow deletion with safety checks - def test_delete_workflow_success(self, workflow_service): + def test_delete_workflow_success(self, workflow_service: WorkflowService): """ Test delete_workflow successfully deletes a published workflow. @@ -1085,7 +1085,7 @@ class TestWorkflowService: assert result is True mock_session.delete.assert_called_once_with(mock_workflow) - def test_delete_workflow_draft_raises_error(self, workflow_service): + def test_delete_workflow_draft_raises_error(self, workflow_service: WorkflowService): """ Test delete_workflow raises error when trying to delete draft. @@ -1109,7 +1109,7 @@ class TestWorkflowService: with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow"): workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id) - def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service): + def test_delete_workflow_in_use_by_app_raises_error(self, workflow_service: WorkflowService): """ Test delete_workflow raises error when workflow is in use by app. @@ -1132,7 +1132,7 @@ class TestWorkflowService: with pytest.raises(WorkflowInUseError, match="currently in use by app"): workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id) - def test_delete_workflow_published_as_tool_raises_error(self, workflow_service): + def test_delete_workflow_published_as_tool_raises_error(self, workflow_service: WorkflowService): """ Test delete_workflow raises error when workflow is published as tool. @@ -1156,7 +1156,7 @@ class TestWorkflowService: with pytest.raises(WorkflowInUseError, match="published as a tool"): workflow_service.delete_workflow(session=mock_session, workflow_id=workflow_id, tenant_id=tenant_id) - def test_delete_workflow_not_found_raises_error(self, workflow_service): + def test_delete_workflow_not_found_raises_error(self, workflow_service: WorkflowService): """Test delete_workflow raises error when workflow not found.""" workflow_id = "nonexistent" tenant_id = "tenant-456" @@ -1175,7 +1175,7 @@ class TestWorkflowService: # ==================== Get Default Block Config Tests ==================== # These tests verify retrieval of default node configurations - def test_get_default_block_configs(self, workflow_service): + def test_get_default_block_configs(self, workflow_service: WorkflowService): """ Test get_default_block_configs returns list of default configs. @@ -1195,7 +1195,7 @@ class TestWorkflowService: assert len(result) > 0 - def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service): + def test_get_default_block_configs_http_request_injects_default_config(self, workflow_service: WorkflowService): injected_config = HttpRequestNodeConfig( max_connect_timeout=15, max_read_timeout=25, @@ -1234,7 +1234,7 @@ class TestWorkflowService: assert passed_http_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config mock_llm_node_class.get_default_config.assert_called_once_with(filters=None) - def test_get_default_block_config_for_node_type(self, workflow_service): + def test_get_default_block_config_for_node_type(self, workflow_service: WorkflowService): """ Test get_default_block_config returns config for specific node type. @@ -1258,7 +1258,7 @@ class TestWorkflowService: assert result == mock_config mock_node_class.get_default_config.assert_called_once() - def test_get_default_block_config_invalid_node_type(self, workflow_service): + def test_get_default_block_config_invalid_node_type(self, workflow_service: WorkflowService): """Test get_default_block_config returns empty dict for invalid node type.""" with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping: mock_mapping.return_value = {} @@ -1268,7 +1268,7 @@ class TestWorkflowService: assert result == {} - def test_get_default_block_config_http_request_injects_default_config(self, workflow_service): + def test_get_default_block_config_http_request_injects_default_config(self, workflow_service: WorkflowService): injected_config = HttpRequestNodeConfig( max_connect_timeout=11, max_read_timeout=22, @@ -1299,7 +1299,7 @@ class TestWorkflowService: passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"] assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is injected_config - def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service): + def test_get_default_block_config_http_request_uses_passed_config(self, workflow_service: WorkflowService): provided_config = HttpRequestNodeConfig( max_connect_timeout=13, max_read_timeout=23, @@ -1330,7 +1330,9 @@ class TestWorkflowService: passed_filters = mock_node_class.get_default_config.call_args.kwargs["filters"] assert passed_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] is provided_config - def test_get_default_block_config_http_request_malformed_config_raises_type_error(self, workflow_service): + def test_get_default_block_config_http_request_malformed_config_raises_type_error( + self, workflow_service: WorkflowService + ): with ( patch( "services.workflow_service.get_node_type_classes_mapping", @@ -1347,7 +1349,7 @@ class TestWorkflowService: # ==================== Workflow Conversion Tests ==================== # These tests verify converting basic apps to workflow apps - def test_convert_to_workflow_from_chat_app(self, workflow_service): + def test_convert_to_workflow_from_chat_app(self, workflow_service: WorkflowService): """ Test convert_to_workflow converts chat app to workflow. @@ -1374,7 +1376,7 @@ class TestWorkflowService: assert result == mock_new_app mock_converter.convert_to_workflow.assert_called_once() - def test_convert_to_workflow_from_completion_app(self, workflow_service): + def test_convert_to_workflow_from_completion_app(self, workflow_service: WorkflowService): """ Test convert_to_workflow converts completion app to workflow. @@ -1395,7 +1397,7 @@ class TestWorkflowService: assert result == mock_new_app - def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service): + def test_convert_to_workflow_invalid_mode_raises_error(self, workflow_service: WorkflowService): """ Test convert_to_workflow raises error for invalid app mode.