From 7a880ae60c7c1251a0116d733a74a60cdd0ed902 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Wed, 15 Apr 2026 12:43:22 +0800 Subject: [PATCH] fix: import DSL and copy app not work (#35239) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/app/app.py | 13 +- api/controllers/console/app/app_import.py | 20 ++- api/controllers/inner_api/app/dsl.py | 8 +- .../console/app/test_app_import_api.py | 50 +++++++ .../console/app/test_app_import_api.py | 139 ++++++++++++++++++ .../controllers/inner_api/app/test_dsl.py | 18 ++- 6 files changed, 231 insertions(+), 17 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/app/test_app_import_api.py diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 75569eb596..ac0682486b 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -8,7 +8,7 @@ from flask_restx import Resource from graphon.enums import WorkflowExecutionStatus from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from controllers.common.helpers import FileInfo @@ -37,7 +37,7 @@ from models.model import IconType from services.app_dsl_service import AppDslService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService -from services.entities.dsl_entities import ImportMode +from services.entities.dsl_entities import ImportMode, ImportStatus from services.entities.knowledge_entities.knowledge_entities import ( DataSource, InfoList, @@ -623,7 +623,7 @@ class AppCopyApi(Resource): args = CopyAppPayload.model_validate(console_ns.payload or {}) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) result = import_service.import_app( @@ -636,6 +636,13 @@ class AppCopyApi(Resource): icon=args.icon, icon_background=args.icon_background, ) + if result.status == ImportStatus.FAILED: + session.rollback() + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + session.rollback() + return result.model_dump(mode="json"), 202 + session.commit() # Inherit web app permission from original app if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 80bd7d1d8d..e91dc9cfe5 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,6 @@ from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import register_schema_models from controllers.console.app.wraps import get_app_model @@ -52,8 +52,9 @@ class AppImportApi(Resource): current_user, _ = current_account_with_tenant() args = AppImportPayload.model_validate(console_ns.payload) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # AppDslService performs internal commits for some creation paths, so use a plain + # Session here instead of nesting it inside sessionmaker(...).begin(). + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Import app account = current_user @@ -69,6 +70,10 @@ class AppImportApi(Resource): icon_background=args.icon_background, app_id=args.app_id, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") @@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource): # Check user role first current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource): @account_initialization_required @edit_permission_required def get(self, app_model: App): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index 6c15f9aa8b..915a11dcdd 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -9,7 +9,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model from controllers.console.wraps import setup_required @@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource): account.set_tenant_id(workspace_id) - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: dsl_service = AppDslService(session) result = dsl_service.import_app( account=account, @@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource): name=args.name, description=args.description, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index d8c6821f8d..25d19cf35a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -96,6 +96,56 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED + def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.commit.assert_called_once_with() + fake_session.rollback.assert_not_called() + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.rollback.assert_called_once_with() + fake_session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + class TestAppImportConfirmApi: @pytest.fixture diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 0000000000..9c4678aed3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,139 @@ +"""Unit tests for console app import endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + return fake_session + + +class TestAppImportApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportApi() + + def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=True) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +class TestAppImportConfirmApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportConfirmApi() + + def test_import_confirm_returns_failed_status_and_rolls_back( + self, api, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + method = _unwrap(api.post) + + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 974d8f7bc6..b7419009f0 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -102,16 +102,16 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): - """Patch db, sessionmaker, and AppDslService for import handler tests.""" - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock()) - mock_session_ctx.__exit__ = MagicMock(return_value=False) - mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx))) + """Patch db, Session, and AppDslService for import handler tests.""" + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=False) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker), + patch("controllers.inner_api.app.dsl.Session", return_value=mock_session), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): + self._mock_session = mock_session self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield @@ -147,6 +147,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 200 assert body["status"] == "completed" mock_account.set_tenant_id.assert_called_once_with("ws-123") + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -162,6 +164,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 202 assert body["status"] == "pending" + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -177,6 +181,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 400 assert body["status"] == "failed" + self._mock_session.rollback.assert_called_once_with() + self._mock_session.commit.assert_not_called() @patch("controllers.inner_api.app.dsl._get_active_account") def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):