mirror of
https://github.com/langgenius/dify.git
synced 2026-04-16 02:16:57 +08:00
fix: import DSL and copy app not work (#35239)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5bc0f9513b
commit
7a880ae60c
@ -8,7 +8,7 @@ from flask_restx import Resource
|
||||
from graphon.enums import WorkflowExecutionStatus
|
||||
from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.common.helpers import FileInfo
|
||||
@ -37,7 +37,7 @@ from models.model import IconType
|
||||
from services.app_dsl_service import AppDslService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.entities.dsl_entities import ImportMode
|
||||
from services.entities.dsl_entities import ImportMode, ImportStatus
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DataSource,
|
||||
InfoList,
|
||||
@ -623,7 +623,7 @@ class AppCopyApi(Resource):
|
||||
|
||||
args = CopyAppPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = import_service.import_app(
|
||||
@ -636,6 +636,13 @@ class AppCopyApi(Resource):
|
||||
icon=args.icon,
|
||||
icon_background=args.icon_background,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 400
|
||||
if result.status == ImportStatus.PENDING:
|
||||
session.rollback()
|
||||
return result.model_dump(mode="json"), 202
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
@ -52,8 +52,9 @@ class AppImportApi(Resource):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = AppImportPayload.model_validate(console_ns.payload)
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# AppDslService performs internal commits for some creation paths, so use a plain
|
||||
# Session here instead of nesting it inside sessionmaker(...).begin().
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Import app
|
||||
account = current_user
|
||||
@ -69,6 +70,10 @@ class AppImportApi(Resource):
|
||||
icon_background=args.icon_background,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
# update web app setting as private
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
|
||||
@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource):
|
||||
# Check user role first
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
# Create service with session
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
# Confirm import
|
||||
account = current_user
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED:
|
||||
@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource):
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
import_service = AppDslService(session)
|
||||
result = import_service.check_dependencies(app_model=app_model)
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.console.wraps import setup_required
|
||||
@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource):
|
||||
|
||||
account.set_tenant_id(workspace_id)
|
||||
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
dsl_service = AppDslService(session)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource):
|
||||
name=args.name,
|
||||
description=args.description,
|
||||
)
|
||||
if result.status == ImportStatus.FAILED:
|
||||
session.rollback()
|
||||
else:
|
||||
session.commit()
|
||||
|
||||
if result.status == ImportStatus.FAILED:
|
||||
return result.model_dump(mode="json"), 400
|
||||
|
||||
@ -96,6 +96,56 @@ class TestAppImportApi:
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
fake_session.__exit__.return_value = None
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
fake_session.commit.assert_called_once_with()
|
||||
fake_session.rollback.assert_not_called()
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
api = app_import_module.AppImportApi()
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
fake_session.__exit__.return_value = None
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
fake_session.rollback.assert_called_once_with()
|
||||
fake_session.commit.assert_not_called()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
|
||||
@ -0,0 +1,139 @@
|
||||
"""Unit tests for console app import endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.console.app import app_import as app_import_module
|
||||
from services.app_dsl_service import ImportStatus
|
||||
|
||||
|
||||
def _unwrap(func):
|
||||
bound_self = getattr(func, "__self__", None)
|
||||
while hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
if bound_self is not None:
|
||||
return func.__get__(bound_self, bound_self.__class__)
|
||||
return func
|
||||
|
||||
|
||||
class _Result:
|
||||
def __init__(self, status: ImportStatus, app_id: str | None = "app-1"):
|
||||
self.status = status
|
||||
self.app_id = app_id
|
||||
|
||||
def model_dump(self, mode: str = "json"):
|
||||
return {"status": self.status, "app_id": self.app_id}
|
||||
|
||||
|
||||
def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None:
|
||||
features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled))
|
||||
monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features)
|
||||
|
||||
|
||||
def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock:
|
||||
fake_session = MagicMock()
|
||||
fake_session.__enter__.return_value = fake_session
|
||||
fake_session.__exit__.return_value = None
|
||||
monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object()))
|
||||
monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session)
|
||||
return fake_session
|
||||
|
||||
|
||||
class TestAppImportApi:
|
||||
@pytest.fixture
|
||||
def api(self):
|
||||
return app_import_module.AppImportApi()
|
||||
|
||||
def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
|
||||
def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=False)
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.PENDING),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
assert status == 202
|
||||
assert response["status"] == ImportStatus.PENDING
|
||||
|
||||
def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
_install_features(monkeypatch, enabled=True)
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"import_app",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"),
|
||||
)
|
||||
update_access = MagicMock()
|
||||
monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}):
|
||||
response, status = method()
|
||||
|
||||
session.commit.assert_called_once_with()
|
||||
session.rollback.assert_not_called()
|
||||
update_access.assert_called_once_with("app-123", "private")
|
||||
assert status == 200
|
||||
assert response["status"] == ImportStatus.COMPLETED
|
||||
|
||||
|
||||
class TestAppImportConfirmApi:
|
||||
@pytest.fixture
|
||||
def api(self):
|
||||
return app_import_module.AppImportConfirmApi()
|
||||
|
||||
def test_import_confirm_returns_failed_status_and_rolls_back(
|
||||
self, api, app, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
method = _unwrap(api.post)
|
||||
|
||||
session = _mock_session(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
app_import_module.AppDslService,
|
||||
"confirm_import",
|
||||
lambda *_args, **_kwargs: _Result(ImportStatus.FAILED),
|
||||
)
|
||||
monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1"))
|
||||
|
||||
with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"):
|
||||
response, status = method(import_id="import-1")
|
||||
|
||||
session.rollback.assert_called_once_with()
|
||||
session.commit.assert_not_called()
|
||||
assert status == 400
|
||||
assert response["status"] == ImportStatus.FAILED
|
||||
@ -102,16 +102,16 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_import_deps(self):
|
||||
"""Patch db, sessionmaker, and AppDslService for import handler tests."""
|
||||
mock_session_ctx = MagicMock()
|
||||
mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session_ctx.__exit__ = MagicMock(return_value=False)
|
||||
mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx)))
|
||||
"""Patch db, Session, and AppDslService for import handler tests."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
with (
|
||||
patch("controllers.inner_api.app.dsl.db"),
|
||||
patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker),
|
||||
patch("controllers.inner_api.app.dsl.Session", return_value=mock_session),
|
||||
patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls,
|
||||
):
|
||||
self._mock_session = mock_session
|
||||
self._mock_dsl = MagicMock()
|
||||
mock_dsl_cls.return_value = self._mock_dsl
|
||||
yield
|
||||
@ -147,6 +147,8 @@ class TestEnterpriseAppDSLImport:
|
||||
assert status_code == 200
|
||||
assert body["status"] == "completed"
|
||||
mock_account.set_tenant_id.assert_called_once_with("ws-123")
|
||||
self._mock_session.commit.assert_called_once_with()
|
||||
self._mock_session.rollback.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
@ -162,6 +164,8 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
assert status_code == 202
|
||||
assert body["status"] == "pending"
|
||||
self._mock_session.commit.assert_called_once_with()
|
||||
self._mock_session.rollback.assert_not_called()
|
||||
|
||||
@pytest.mark.usefixtures("_mock_import_deps")
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
@ -177,6 +181,8 @@ class TestEnterpriseAppDSLImport:
|
||||
|
||||
assert status_code == 400
|
||||
assert body["status"] == "failed"
|
||||
self._mock_session.rollback.assert_called_once_with()
|
||||
self._mock_session.commit.assert_not_called()
|
||||
|
||||
@patch("controllers.inner_api.app.dsl._get_active_account")
|
||||
def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user