From 9b4dd9d4e8d82b57e5837d7ca7224ee902ccb559 Mon Sep 17 00:00:00 2001 From: Rohit Gahlawat <283466839+Rohit-Gahlawat@users.noreply.github.com> Date: Sun, 21 Jun 2026 06:23:36 +0530 Subject: [PATCH] refactor: accept db.session explicitly in APIBasedExtensionService (#37693) --- api/controllers/console/extension.py | 24 ++++-- api/services/api_based_extension_service.py | 30 +++---- .../console/test_api_based_extension.py | 3 +- .../test_api_based_extension_service.py | 79 +++++++++++-------- .../controllers/console/test_extension.py | 14 ++-- 5 files changed, 85 insertions(+), 65 deletions(-) diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 6d9362ae0b1..ec1e01dc460 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -7,6 +7,7 @@ from flask_restx import Resource from pydantic import BaseModel, Field, TypeAdapter, field_validator from constants import HIDDEN_VALUE +from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import to_timestamp from libs.login import login_required @@ -126,7 +127,7 @@ class APIBasedExtensionAPI(Resource): def get(self, current_tenant_id: str): return [ _serialize_api_based_extension(extension) - for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id) + for extension in APIBasedExtensionService.get_all_by_tenant_id(db.session(), current_tenant_id) ] @console_ns.doc("create_api_based_extension") @@ -147,7 +148,12 @@ class APIBasedExtensionAPI(Resource): api_key=payload.api_key, ) - return _serialize_saved_api_based_extension(APIBasedExtensionService.save(extension_data), payload.api_key), 201 + return ( + _serialize_saved_api_based_extension( + APIBasedExtensionService.save(db.session(), extension_data), payload.api_key + ), + 201, + ) @console_ns.route("/api-based-extension/") @@ -164,7 +170,7 @@ class APIBasedExtensionDetailAPI(Resource): api_based_extension_id = str(id) return _serialize_api_based_extension( - APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) + APIBasedExtensionService.get_with_tenant_id(db.session(), current_tenant_id, api_based_extension_id) ) @console_ns.doc("update_api_based_extension") @@ -179,7 +185,9 @@ class APIBasedExtensionDetailAPI(Resource): def post(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id( + db.session(), current_tenant_id, api_based_extension_id + ) payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {}) api_key_for_response = extension_data_from_db.api_key @@ -192,7 +200,7 @@ class APIBasedExtensionDetailAPI(Resource): api_key_for_response = payload.api_key return _serialize_saved_api_based_extension( - APIBasedExtensionService.save(extension_data_from_db), + APIBasedExtensionService.save(db.session(), extension_data_from_db), api_key_for_response, ) @@ -207,8 +215,10 @@ class APIBasedExtensionDetailAPI(Resource): def delete(self, current_tenant_id: str, id: UUID): api_based_extension_id = str(id) - extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id) + extension_data_from_db = APIBasedExtensionService.get_with_tenant_id( + db.session(), current_tenant_id, api_based_extension_id + ) - APIBasedExtensionService.delete(extension_data_from_db) + APIBasedExtensionService.delete(db.session(), extension_data_from_db) return "", 204 diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index fdb377694bb..25f554b6bdc 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -1,16 +1,16 @@ from sqlalchemy import select +from sqlalchemy.orm import Session from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token, encrypt_token -from extensions.ext_database import db from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: @staticmethod - def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: + def get_all_by_tenant_id(session: Session, tenant_id: str) -> list[APIBasedExtension]: extension_list = list( - db.session.scalars( + session.scalars( select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id) .order_by(APIBasedExtension.created_at.desc()) @@ -23,23 +23,23 @@ class APIBasedExtensionService: return extension_list @classmethod - def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension: - cls._validation(extension_data) + def save(cls, session: Session, extension_data: APIBasedExtension) -> APIBasedExtension: + cls._validation(session, extension_data) extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key) - db.session.add(extension_data) - db.session.commit() + session.add(extension_data) + session.commit() return extension_data @staticmethod - def delete(extension_data: APIBasedExtension): - db.session.delete(extension_data) - db.session.commit() + def delete(session: Session, extension_data: APIBasedExtension): + session.delete(extension_data) + session.commit() @staticmethod - def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.scalar( + def get_with_tenant_id(session: Session, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + extension = session.scalar( select(APIBasedExtension) .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id) .limit(1) @@ -53,14 +53,14 @@ class APIBasedExtensionService: return extension @classmethod - def _validation(cls, extension_data: APIBasedExtension): + def _validation(cls, session: Session, extension_data: APIBasedExtension): # name if not extension_data.name: raise ValueError("name must not be empty") if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.scalar( + is_name_existed = session.scalar( select(APIBasedExtension) .where( APIBasedExtension.tenant_id == extension_data.tenant_id, @@ -73,7 +73,7 @@ class APIBasedExtensionService: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = db.session.scalar( + is_name_existed = session.scalar( select(APIBasedExtension) .where( APIBasedExtension.tenant_id == extension_data.tenant_id, diff --git a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py index 058f4e5fa34..e60558040a5 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py +++ b/api/tests/test_containers_integration_tests/controllers/console/test_api_based_extension.py @@ -97,12 +97,13 @@ def test_list_scopes_api_based_extensions_to_authenticated_tenant( assert account_create_response.status_code == 201 APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=foreign_tenant_id, name="Foreign API", api_endpoint="https://foreign.example.com/hook", api_key="foreign-secret-12345", - ) + ), ) response = test_client_with_containers.get( diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index b8e022503fd..8bd4069639f 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -81,7 +81,7 @@ class TestAPIBasedExtensionService: ) # Save extension - saved_extension = APIBasedExtensionService.save(extension_data) + saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Verify extension was saved correctly assert saved_extension.id is not None @@ -119,21 +119,21 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test empty api_endpoint extension_data.name = fake.company() extension_data.api_endpoint = "" with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test empty api_key extension_data.api_endpoint = f"https://{fake.domain_name()}/api" extension_data.api_key = "" with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_get_all_by_tenant_id_success( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -157,11 +157,11 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - saved_extension = APIBasedExtensionService.save(extension_data) + saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) extensions.append(saved_extension) # Get all extensions for tenant - extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id) # Verify results assert len(extension_list) == 3 @@ -191,10 +191,12 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Get extension by ID - retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + retrieved_extension = APIBasedExtensionService.get_with_tenant_id( + db_session_with_containers, tenant.id, created_extension.id + ) # Verify extension was retrieved correctly assert retrieved_extension is not None @@ -219,7 +221,9 @@ class TestAPIBasedExtensionService: # Try to get non-existent extension with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) + APIBasedExtensionService.get_with_tenant_id( + db_session_with_containers, tenant.id, non_existent_extension_id + ) def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -238,11 +242,11 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) extension_id = created_extension.id # Delete the extension - APIBasedExtensionService.delete(created_extension) + APIBasedExtensionService.delete(db_session_with_containers, created_extension) # Verify extension was deleted @@ -270,7 +274,7 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - APIBasedExtensionService.save(extension_data1) + APIBasedExtensionService.save(db_session_with_containers, extension_data1) # Try to create second extension with same name extension_data2 = APIBasedExtension( tenant_id=tenant.id, @@ -280,7 +284,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService.save(extension_data2) + APIBasedExtensionService.save(db_session_with_containers, extension_data2) def test_save_extension_update_existing( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -301,7 +305,7 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Save original values for later comparison original_name = created_extension.name @@ -320,7 +324,7 @@ class TestAPIBasedExtensionService: created_extension.api_endpoint = new_endpoint created_extension.api_key = new_api_key - updated_extension = APIBasedExtensionService.save(created_extension) + updated_extension = APIBasedExtensionService.save(db_session_with_containers, created_extension) # Verify extension was updated correctly assert updated_extension.id == created_extension.id @@ -336,7 +340,9 @@ class TestAPIBasedExtensionService: assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2 # Verify the update by retrieving the extension again - retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id) + retrieved_extension = APIBasedExtensionService.get_with_tenant_id( + db_session_with_containers, tenant.id, created_extension.id + ) assert retrieved_extension.name == new_name assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved @@ -367,7 +373,7 @@ class TestAPIBasedExtensionService: # Try to save extension with connection error with pytest.raises(ValueError, match="connection error: request timeout"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_invalid_api_key_length( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -390,7 +396,7 @@ class TestAPIBasedExtensionService: # Try to save extension with short API key with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ @@ -410,21 +416,21 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="name must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test with None api_endpoint extension_data.name = fake.company() extension_data.api_endpoint = None with pytest.raises(ValueError, match="api_endpoint must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) # Test with None api_key extension_data.api_endpoint = f"https://{fake.domain_name()}/api" extension_data.api_key = None with pytest.raises(ValueError, match="api_key must not be empty"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_get_all_by_tenant_id_empty_list( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -438,7 +444,7 @@ class TestAPIBasedExtensionService: ) # Get all extensions for tenant (none exist) - extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id) + extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id) # Verify empty list is returned assert len(extension_list) == 0 @@ -468,7 +474,7 @@ class TestAPIBasedExtensionService: # Try to save extension with invalid ping response with pytest.raises(ValueError, match="{'result': 'invalid'}"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_missing_ping_result( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -494,7 +500,7 @@ class TestAPIBasedExtensionService: # Try to save extension with missing ping result with pytest.raises(ValueError, match="{'status': 'ok'}"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_get_with_tenant_id_wrong_tenant( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -520,11 +526,11 @@ class TestAPIBasedExtensionService: api_key=fake.password(length=20), ) - created_extension = APIBasedExtensionService.save(extension_data) + created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data) # Try to get extension with wrong tenant ID with pytest.raises(ValueError, match="API based extension is not found"): - APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id) + APIBasedExtensionService.get_with_tenant_id(db_session_with_containers, tenant2.id, created_extension.id) def test_save_extension_api_key_exactly_four_chars_rejected( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -544,7 +550,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="api_key must be at least 5 characters"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_api_key_exactly_five_chars_accepted( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -563,7 +569,7 @@ class TestAPIBasedExtensionService: api_key="12345", ) - saved = APIBasedExtensionService.save(extension_data) + saved = APIBasedExtensionService.save(db_session_with_containers, extension_data) assert saved.id is not None def test_save_extension_requestor_constructor_error( @@ -586,7 +592,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="connection error: bad config"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_network_exception( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -610,7 +616,7 @@ class TestAPIBasedExtensionService: ) with pytest.raises(ValueError, match="connection error: network failure"): - APIBasedExtensionService.save(extension_data) + APIBasedExtensionService.save(db_session_with_containers, extension_data) def test_save_extension_update_duplicate_name_rejected( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -623,26 +629,28 @@ class TestAPIBasedExtensionService: assert tenant is not None ext1 = APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=tenant.id, name="Extension Alpha", api_endpoint=f"https://{fake.domain_name()}/api", api_key=fake.password(length=20), - ) + ), ) ext2 = APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=tenant.id, name="Extension Beta", api_endpoint=f"https://{fake.domain_name()}/api", api_key=fake.password(length=20), - ) + ), ) # Try to rename ext2 to ext1's name ext2.name = "Extension Alpha" with pytest.raises(ValueError, match="name must be unique, it is already existed"): - APIBasedExtensionService.save(ext2) + APIBasedExtensionService.save(db_session_with_containers, ext2) def test_get_all_returns_empty_for_different_tenant( self, db_session_with_containers: Session, mock_external_service_dependencies @@ -658,14 +666,15 @@ class TestAPIBasedExtensionService: assert tenant1 is not None APIBasedExtensionService.save( + db_session_with_containers, APIBasedExtension( tenant_id=tenant1.id, name=fake.company(), api_endpoint=f"https://{fake.domain_name()}/api", api_key=fake.password(length=20), - ) + ), ) assert tenant2 is not None - result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id) + result = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant2.id) assert result == [] diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py index 487cf8f54fd..bab825ca6f0 100644 --- a/api/tests/unit_tests/controllers/console/test_extension.py +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -3,7 +3,7 @@ from __future__ import annotations import builtins import uuid from datetime import UTC, datetime -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest from flask import Flask @@ -114,7 +114,7 @@ def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypat assert response[0]["name"] == "Weather API" assert response[0]["api_endpoint"] == extension.api_endpoint assert response[0]["api_key"].startswith(extension.api_key[:3]) - service_mock.assert_called_once_with("tenant-123") + service_mock.assert_called_once_with(ANY, "tenant-123") def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch): @@ -132,7 +132,7 @@ def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pyt response, status = APIBasedExtensionAPI().post() args, _ = save_mock.call_args - created_extension: APIBasedExtension = args[0] + created_extension: APIBasedExtension = args[1] assert created_extension.tenant_id == "tenant-123" assert created_extension.name == payload["name"] assert created_extension.api_endpoint == payload["api_endpoint"] @@ -157,7 +157,7 @@ def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatc assert response["id"] == extension.id assert response["name"] == extension.name - service_mock.assert_called_once_with("tenant-123", str(extension_id)) + service_mock.assert_called_once_with(ANY, "tenant-123", str(extension_id)) def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch): @@ -187,7 +187,7 @@ def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkey assert existing_extension.name == payload["name"] assert existing_extension.api_endpoint == payload["api_endpoint"] assert existing_extension.api_key == "keep-me" - save_mock.assert_called_once_with(existing_extension) + save_mock.assert_called_once_with(ANY, existing_extension) assert response["name"] == payload["name"] assert response["api_key"] == _masked_api_key("keep-me") @@ -217,7 +217,7 @@ def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flas response = APIBasedExtensionDetailAPI().post(extension_id) assert existing_extension.api_key == "new-secret" - save_mock.assert_called_once_with(existing_extension) + save_mock.assert_called_once_with(ANY, existing_extension) assert response["name"] == payload["name"] assert response["api_key"] == _masked_api_key(payload["api_key"]) @@ -239,6 +239,6 @@ def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeyp ): response, status = APIBasedExtensionDetailAPI().delete(extension_id) - delete_mock.assert_called_once_with(existing_extension) + delete_mock.assert_called_once_with(ANY, existing_extension) assert status == 204 assert response == ""