diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 75569eb596..ac0682486b 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -8,7 +8,7 @@ from flask_restx import Resource from graphon.enums import WorkflowExecutionStatus from pydantic import AliasChoices, BaseModel, Field, computed_field, field_validator from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest from controllers.common.helpers import FileInfo @@ -37,7 +37,7 @@ from models.model import IconType from services.app_dsl_service import AppDslService from services.app_service import AppService from services.enterprise.enterprise_service import EnterpriseService -from services.entities.dsl_entities import ImportMode +from services.entities.dsl_entities import ImportMode, ImportStatus from services.entities.knowledge_entities.knowledge_entities import ( DataSource, InfoList, @@ -623,7 +623,7 @@ class AppCopyApi(Resource): args = CopyAppPayload.model_validate(console_ns.payload or {}) - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True) result = import_service.import_app( @@ -636,6 +636,13 @@ class AppCopyApi(Resource): icon=args.icon, icon_background=args.icon_background, ) + if result.status == ImportStatus.FAILED: + session.rollback() + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + session.rollback() + return result.model_dump(mode="json"), 202 + session.commit() # Inherit web app permission from original app if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 80bd7d1d8d..e91dc9cfe5 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -1,6 +1,6 @@ from flask_restx import Resource from pydantic import BaseModel, Field -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import register_schema_models from controllers.console.app.wraps import get_app_model @@ -52,8 +52,9 @@ class AppImportApi(Resource): current_user, _ = current_account_with_tenant() args = AppImportPayload.model_validate(console_ns.payload) - # Create service with session - with sessionmaker(db.engine).begin() as session: + # AppDslService performs internal commits for some creation paths, so use a plain + # Session here instead of nesting it inside sessionmaker(...).begin(). + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Import app account = current_user @@ -69,6 +70,10 @@ class AppImportApi(Resource): icon_background=args.icon_background, app_id=args.app_id, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: # update web app setting as private EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") @@ -95,12 +100,15 @@ class AppImportConfirmApi(Resource): # Check user role first current_user, _ = current_account_with_tenant() - # Create service with session - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) # Confirm import account = current_user result = import_service.confirm_import(import_id=import_id, account=account) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() # Return appropriate status code based on result if result.status == ImportStatus.FAILED: @@ -117,7 +125,7 @@ class AppImportCheckDependenciesApi(Resource): @account_initialization_required @edit_permission_required def get(self, app_model: App): - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: import_service = AppDslService(session) result = import_service.check_dependencies(app_model=app_model) diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index da8d25c2eb..5e6ff87d62 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from typing import Any from flask import abort, request -from flask_restx import Resource, fields, marshal_with +from flask_restx import Resource, fields, marshal, marshal_with from graphon.enums import NodeType from graphon.file import File from graphon.graph_engine.manager import GraphEngineManager @@ -942,7 +942,6 @@ class PublishedAllWorkflowApi(Resource): @login_required @account_initialization_required @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) - @marshal_with(workflow_pagination_model) @edit_permission_required def get(self, app_model: App): """ @@ -970,9 +969,10 @@ class PublishedAllWorkflowApi(Resource): user_id=user_id, named_only=named_only, ) + serialized_workflows = marshal(workflows, workflow_fields_copy) return { - "items": workflows, + "items": serialized_workflows, "page": page, "limit": limit, "has_more": has_more, diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py index 6c15f9aa8b..915a11dcdd 100644 --- a/api/controllers/inner_api/app/dsl.py +++ b/api/controllers/inner_api/app/dsl.py @@ -9,7 +9,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field from sqlalchemy import select -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model from controllers.console.wraps import setup_required @@ -56,7 +56,7 @@ class EnterpriseAppDSLImport(Resource): account.set_tenant_id(workspace_id) - with sessionmaker(db.engine).begin() as session: + with Session(db.engine, expire_on_commit=False) as session: dsl_service = AppDslService(session) result = dsl_service.import_app( account=account, @@ -65,6 +65,10 @@ class EnterpriseAppDSLImport(Resource): name=args.name, description=args.description, ) + if result.status == ImportStatus.FAILED: + session.rollback() + else: + session.commit() if result.status == ImportStatus.FAILED: return result.model_dump(mode="json"), 400 diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py deleted file mode 100644 index 487178ff58..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_executor.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest - -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor - -CODE_LANGUAGE = "unsupported_language" - - -def test_unsupported_with_code_template(): - with pytest.raises(CodeExecutionError) as e: - CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) - assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" diff --git a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py b/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py deleted file mode 100644 index 25af312afa..0000000000 --- a/api/tests/integration_tests/workflow/nodes/code_executor/test_code_python3.py +++ /dev/null @@ -1,36 +0,0 @@ -from textwrap import dedent - -from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage -from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider -from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer - -CODE_LANGUAGE = CodeLanguage.PYTHON3 - - -def test_python3_plain(): - code = 'print("Hello World")' - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == "Hello World\n" - - -def test_python3_json(): - code = dedent(""" - import json - print(json.dumps({'Hello': 'World'})) - """) - result = CodeExecutor.execute_code(language=CODE_LANGUAGE, preload="", code=code) - assert result == '{"Hello": "World"}\n' - - -def test_python3_with_code_template(): - result = CodeExecutor.execute_workflow_code_template( - language=CODE_LANGUAGE, code=Python3CodeProvider.get_default_code(), inputs={"arg1": "Hello", "arg2": "World"} - ) - assert result == {"result": "HelloWorld"} - - -def test_python3_get_runner_script(): - runner_script = Python3TemplateTransformer.get_runner_script() - assert runner_script.count(Python3TemplateTransformer._code_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._inputs_placeholder) == 1 - assert runner_script.count(Python3TemplateTransformer._result_tag) == 2 diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py index d8c6821f8d..25d19cf35a 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_app_import_api.py @@ -96,6 +96,56 @@ class TestAppImportApi: assert status == 200 assert response["status"] == ImportStatus.COMPLETED + def test_import_post_commits_session_on_success(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.commit.assert_called_once_with() + fake_session.rollback.assert_not_called() + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + def test_import_post_rolls_back_session_on_failure(self, app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + fake_session.rollback.assert_called_once_with() + fake_session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + class TestAppImportConfirmApi: @pytest.fixture diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py new file mode 100644 index 0000000000..4e884626a7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/controllers/service_api/test_site.py @@ -0,0 +1,110 @@ +""" +Testcontainers integration tests for Service API Site controller. +""" + +from __future__ import annotations + +import pytest +from flask import Flask +from sqlalchemy.orm import Session +from werkzeug.exceptions import Forbidden + +from controllers.service_api.app.site import AppSiteApi +from models.account import Tenant, TenantStatus +from models.model import App, AppMode, Site + + +@pytest.fixture +def app(flask_app_with_containers) -> Flask: + return flask_app_with_containers + + +def _unwrap(method): + fn = method + while hasattr(fn, "__wrapped__"): + fn = fn.__wrapped__ + return fn + + +def _create_tenant(db_session: Session, *, status: TenantStatus = TenantStatus.NORMAL) -> Tenant: + tenant = Tenant(name="service-api-site-tenant", status=status) + db_session.add(tenant) + db_session.commit() + return tenant + + +def _create_app(db_session: Session, tenant_id: str) -> App: + app_model = App( + tenant_id=tenant_id, + mode=AppMode.CHAT, + name="service-api-site-app", + enable_site=True, + enable_api=True, + status="normal", + ) + db_session.add(app_model) + db_session.commit() + return app_model + + +def _create_site(db_session: Session, app_id: str) -> Site: + site = Site( + app_id=app_id, + title="Service API Site", + icon_type="emoji", + icon="robot", + icon_background="#ffffff", + description="Service API test site", + default_language="en-US", + prompt_public=True, + show_workflow_steps=True, + customize_token_strategy="not_allow", + use_icon_as_answer_icon=False, + chat_color_theme="light", + chat_color_theme_inverted=False, + ) + db_session.add(site) + db_session.commit() + return site + + +class TestAppSiteApi: + def test_get_site_success(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + response = _unwrap(api.get)(api, app_model=app_model) + + assert response["title"] == "Service API Site" + assert response["icon"] == "robot" + assert response["description"] == "Service API test site" + + def test_get_site_not_found(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) + + def test_get_site_tenant_archived(self, app: Flask, db_session_with_containers: Session) -> None: + tenant = _create_tenant(db_session_with_containers) + app_model = _create_app(db_session_with_containers, tenant.id) + _create_site(db_session_with_containers, app_model.id) + + archived_tenant = db_session_with_containers.get(Tenant, tenant.id) + assert archived_tenant is not None + archived_tenant.status = TenantStatus.ARCHIVE + db_session_with_containers.commit() + + app_model = db_session_with_containers.get(App, app_model.id) + assert app_model is not None + + with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test-token"}): + api = AppSiteApi() + with pytest.raises(Forbidden): + _unwrap(api.get)(api, app_model=app_model) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 0000000000..9c4678aed3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,139 @@ +"""Unit tests for console app import endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def _mock_session(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + fake_session = MagicMock() + fake_session.__enter__.return_value = fake_session + fake_session.__exit__.return_value = None + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(app_import_module, "Session", lambda *_args, **_kwargs: fake_session) + return fake_session + + +class TestAppImportApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportApi() + + def test_import_post_returns_failed_status_and_rolls_back(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + def test_import_post_returns_pending_status_and_commits(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=False) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + def test_import_post_updates_webapp_auth_when_enabled(self, api, app, monkeypatch: pytest.MonkeyPatch) -> None: + method = _unwrap(api.post) + + _install_features(monkeypatch, enabled=True) + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once_with() + session.rollback.assert_not_called() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +class TestAppImportConfirmApi: + @pytest.fixture + def api(self): + return app_import_module.AppImportConfirmApi() + + def test_import_confirm_returns_failed_status_and_rolls_back( + self, api, app, monkeypatch: pytest.MonkeyPatch + ) -> None: + method = _unwrap(api.post) + + session = _mock_session(monkeypatch) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.rollback.assert_called_once_with() + session.commit.assert_not_called() + assert status == 400 + assert response["status"] == ImportStatus.FAILED diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 3607636880..f32d0ef0ec 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -258,6 +258,63 @@ def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( assert exc.value.description == "invalid workflow graph" +def test_get_published_workflows_marshals_items_before_session_closes(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.PublishedAllWorkflowApi() + handler = _unwrap(api.get) + + session_state = {"open": False} + + class _SessionContext: + def __enter__(self): + session_state["open"] = True + return object() + + def __exit__(self, exc_type, exc, tb): + session_state["open"] = False + return False + + class _SessionMaker: + def begin(self): + return _SessionContext() + + class _Workflow: + @property + def id(self): + assert session_state["open"] is True + return "w1" + + monkeypatch.setattr(workflow_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(workflow_module, "sessionmaker", lambda *_args, **_kwargs: _SessionMaker()) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + get_all_published_workflow=lambda **_kwargs: ([_Workflow()], False), + ), + ) + + def _fake_marshal(items, fields): + assert session_state["open"] is True + return [{"id": item.id} for item in items] + + monkeypatch.setattr(workflow_module, "marshal", _fake_marshal) + + with app.test_request_context( + "/apps/app/workflows", + method="GET", + query_string={"page": 1, "limit": 10, "user_id": "", "named_only": "false"}, + ): + response = handler(api, app_model=SimpleNamespace(id="app", workflow_id="wf-1")) + + assert response == { + "items": [{"id": "w1"}], + "page": 1, + "limit": 10, + "has_more": False, + } + + def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py index 974d8f7bc6..b7419009f0 100644 --- a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -102,16 +102,16 @@ class TestEnterpriseAppDSLImport: @pytest.fixture def _mock_import_deps(self): - """Patch db, sessionmaker, and AppDslService for import handler tests.""" - mock_session_ctx = MagicMock() - mock_session_ctx.__enter__ = MagicMock(return_value=MagicMock()) - mock_session_ctx.__exit__ = MagicMock(return_value=False) - mock_sessionmaker = MagicMock(return_value=MagicMock(begin=MagicMock(return_value=mock_session_ctx))) + """Patch db, Session, and AppDslService for import handler tests.""" + mock_session = MagicMock() + mock_session.__enter__ = MagicMock(return_value=mock_session) + mock_session.__exit__ = MagicMock(return_value=False) with ( patch("controllers.inner_api.app.dsl.db"), - patch("controllers.inner_api.app.dsl.sessionmaker", mock_sessionmaker), + patch("controllers.inner_api.app.dsl.Session", return_value=mock_session), patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, ): + self._mock_session = mock_session self._mock_dsl = MagicMock() mock_dsl_cls.return_value = self._mock_dsl yield @@ -147,6 +147,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 200 assert body["status"] == "completed" mock_account.set_tenant_id.assert_called_once_with("ws-123") + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -162,6 +164,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 202 assert body["status"] == "pending" + self._mock_session.commit.assert_called_once_with() + self._mock_session.rollback.assert_not_called() @pytest.mark.usefixtures("_mock_import_deps") @patch("controllers.inner_api.app.dsl._get_active_account") @@ -177,6 +181,8 @@ class TestEnterpriseAppDSLImport: assert status_code == 400 assert body["status"] == "failed" + self._mock_session.rollback.assert_called_once_with() + self._mock_session.commit.assert_not_called() @patch("controllers.inner_api.app.dsl._get_active_account") def test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py deleted file mode 100644 index c0b40d070a..0000000000 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ /dev/null @@ -1,270 +0,0 @@ -""" -Unit tests for Service API Site controller -""" - -import uuid -from unittest.mock import Mock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.service_api.app.site import AppSiteApi -from models.account import TenantStatus -from models.model import App, Site -from tests.unit_tests.conftest import setup_mock_tenant_account_query - - -class TestAppSiteApi: - """Test suite for AppSiteApi""" - - @pytest.fixture - def mock_app_model(self): - """Create a mock App model with tenant.""" - app = Mock(spec=App) - app.id = str(uuid.uuid4()) - app.tenant_id = str(uuid.uuid4()) - app.status = "normal" - app.enable_api = True - - mock_tenant = Mock() - mock_tenant.id = app.tenant_id - mock_tenant.status = TenantStatus.NORMAL - app.tenant = mock_tenant - - return app - - @pytest.fixture - def mock_site(self): - """Create a mock Site model.""" - site = Mock(spec=Site) - site.id = str(uuid.uuid4()) - site.app_id = str(uuid.uuid4()) - site.title = "Test Site" - site.icon = "icon-url" - site.icon_background = "#ffffff" - site.description = "Site description" - site.copyright = "Copyright 2024" - site.privacy_policy = "Privacy policy text" - site.custom_disclaimer = "Custom disclaimer" - site.default_language = "en-US" - site.prompt_public = True - site.show_workflow_steps = True - site.use_icon_as_answer_icon = False - site.chat_color_theme = "light" - site.chat_color_theme_inverted = False - site.icon_type = "image" - site.created_at = "2024-01-01T00:00:00" - site.updated_at = "2024-01-01T00:00:00" - return site - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_success( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test successful retrieval of site configuration.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - # Mock wraps.db for authentication - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site.db for site query - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - response = api.get() - - # Assert - assert response["title"] == "Test Site" - assert response["icon"] == "icon-url" - assert response["description"] == "Site description" - mock_db.session.scalar.assert_called_once() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_not_found( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - ): - """Test that Forbidden is raised when site is not found.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query to return None - mock_db.session.scalar.return_value = None - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_tenant_archived( - self, - mock_wraps_db, - mock_validate_token, - mock_current_app, - mock_db, - mock_user_logged_in, - app, - mock_app_model, - mock_site, - ): - """Test that Forbidden is raised when tenant is archived.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - # Mock site query - mock_db.session.scalar.return_value = mock_site - - # Set tenant status to archived AFTER authentication - mock_app_model.tenant.status = TenantStatus.ARCHIVE - - # Act & Assert - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - with pytest.raises(Forbidden): - api.get() - - @patch("controllers.service_api.wraps.user_logged_in") - @patch("controllers.service_api.app.site.db") - @patch("controllers.service_api.wraps.current_app") - @patch("controllers.service_api.wraps.validate_and_get_api_token") - @patch("controllers.service_api.wraps.db") - def test_get_site_queries_by_app_id( - self, mock_wraps_db, mock_validate_token, mock_current_app, mock_db, mock_user_logged_in, app, mock_app_model - ): - """Test that site is queried using the app model's id.""" - # Arrange - mock_current_app.login_manager = Mock() - - # Mock authentication - mock_api_token = Mock() - mock_api_token.app_id = mock_app_model.id - mock_api_token.tenant_id = mock_app_model.tenant_id - mock_validate_token.return_value = mock_api_token - - mock_tenant = Mock() - mock_tenant.status = TenantStatus.NORMAL - mock_app_model.tenant = mock_tenant - - mock_wraps_db.session.get.side_effect = [ - mock_app_model, - mock_tenant, - ] - - mock_account = Mock() - mock_account.current_tenant = mock_tenant - setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) - - mock_site = Mock(spec=Site) - mock_site.id = str(uuid.uuid4()) - mock_site.app_id = mock_app_model.id - mock_site.title = "Test Site" - mock_site.icon = "icon-url" - mock_site.icon_background = "#ffffff" - mock_site.description = "Site description" - mock_site.copyright = "Copyright 2024" - mock_site.privacy_policy = "Privacy policy text" - mock_site.custom_disclaimer = "Custom disclaimer" - mock_site.default_language = "en-US" - mock_site.prompt_public = True - mock_site.show_workflow_steps = True - mock_site.use_icon_as_answer_icon = False - mock_site.chat_color_theme = "light" - mock_site.chat_color_theme_inverted = False - mock_site.icon_type = "image" - mock_site.created_at = "2024-01-01T00:00:00" - mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.scalar.return_value = mock_site - - # Act - with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): - api = AppSiteApi() - api.get() - - # Assert - # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.scalar.assert_called_once()