refactor: accept db.session explicitly in APIBasedExtensionService (#37693)

This commit is contained in:
Rohit Gahlawat 2026-06-21 06:23:36 +05:30 committed by GitHub
parent 75d50455d6
commit 9b4dd9d4e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 85 additions and 65 deletions

View File

@ -7,6 +7,7 @@ from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from constants import HIDDEN_VALUE
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import to_timestamp
from libs.login import login_required
@ -126,7 +127,7 @@ class APIBasedExtensionAPI(Resource):
def get(self, current_tenant_id: str):
return [
_serialize_api_based_extension(extension)
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
for extension in APIBasedExtensionService.get_all_by_tenant_id(db.session(), current_tenant_id)
]
@console_ns.doc("create_api_based_extension")
@ -147,7 +148,12 @@ class APIBasedExtensionAPI(Resource):
api_key=payload.api_key,
)
return _serialize_saved_api_based_extension(APIBasedExtensionService.save(extension_data), payload.api_key), 201
return (
_serialize_saved_api_based_extension(
APIBasedExtensionService.save(db.session(), extension_data), payload.api_key
),
201,
)
@console_ns.route("/api-based-extension/<uuid:id>")
@ -164,7 +170,7 @@ class APIBasedExtensionDetailAPI(Resource):
api_based_extension_id = str(id)
return _serialize_api_based_extension(
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
APIBasedExtensionService.get_with_tenant_id(db.session(), current_tenant_id, api_based_extension_id)
)
@console_ns.doc("update_api_based_extension")
@ -179,7 +185,9 @@ class APIBasedExtensionDetailAPI(Resource):
def post(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(
db.session(), current_tenant_id, api_based_extension_id
)
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
api_key_for_response = extension_data_from_db.api_key
@ -192,7 +200,7 @@ class APIBasedExtensionDetailAPI(Resource):
api_key_for_response = payload.api_key
return _serialize_saved_api_based_extension(
APIBasedExtensionService.save(extension_data_from_db),
APIBasedExtensionService.save(db.session(), extension_data_from_db),
api_key_for_response,
)
@ -207,8 +215,10 @@ class APIBasedExtensionDetailAPI(Resource):
def delete(self, current_tenant_id: str, id: UUID):
api_based_extension_id = str(id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(
db.session(), current_tenant_id, api_based_extension_id
)
APIBasedExtensionService.delete(extension_data_from_db)
APIBasedExtensionService.delete(db.session(), extension_data_from_db)
return "", 204

View File

@ -1,16 +1,16 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token, encrypt_token
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class APIBasedExtensionService:
@staticmethod
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
def get_all_by_tenant_id(session: Session, tenant_id: str) -> list[APIBasedExtension]:
extension_list = list(
db.session.scalars(
session.scalars(
select(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id)
.order_by(APIBasedExtension.created_at.desc())
@ -23,23 +23,23 @@ class APIBasedExtensionService:
return extension_list
@classmethod
def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
cls._validation(extension_data)
def save(cls, session: Session, extension_data: APIBasedExtension) -> APIBasedExtension:
cls._validation(session, extension_data)
extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
db.session.add(extension_data)
db.session.commit()
session.add(extension_data)
session.commit()
return extension_data
@staticmethod
def delete(extension_data: APIBasedExtension):
db.session.delete(extension_data)
db.session.commit()
def delete(session: Session, extension_data: APIBasedExtension):
session.delete(extension_data)
session.commit()
@staticmethod
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.scalar(
def get_with_tenant_id(session: Session, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = session.scalar(
select(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.limit(1)
@ -53,14 +53,14 @@ class APIBasedExtensionService:
return extension
@classmethod
def _validation(cls, extension_data: APIBasedExtension):
def _validation(cls, session: Session, extension_data: APIBasedExtension):
# name
if not extension_data.name:
raise ValueError("name must not be empty")
if not extension_data.id:
# case one: check new data, name must be unique
is_name_existed = db.session.scalar(
is_name_existed = session.scalar(
select(APIBasedExtension)
.where(
APIBasedExtension.tenant_id == extension_data.tenant_id,
@ -73,7 +73,7 @@ class APIBasedExtensionService:
raise ValueError("name must be unique, it is already existed")
else:
# case two: check existing data, name must be unique
is_name_existed = db.session.scalar(
is_name_existed = session.scalar(
select(APIBasedExtension)
.where(
APIBasedExtension.tenant_id == extension_data.tenant_id,

View File

@ -97,12 +97,13 @@ def test_list_scopes_api_based_extensions_to_authenticated_tenant(
assert account_create_response.status_code == 201
APIBasedExtensionService.save(
db_session_with_containers,
APIBasedExtension(
tenant_id=foreign_tenant_id,
name="Foreign API",
api_endpoint="https://foreign.example.com/hook",
api_key="foreign-secret-12345",
)
),
)
response = test_client_with_containers.get(

View File

@ -81,7 +81,7 @@ class TestAPIBasedExtensionService:
)
# Save extension
saved_extension = APIBasedExtensionService.save(extension_data)
saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Verify extension was saved correctly
assert saved_extension.id is not None
@ -119,21 +119,21 @@ class TestAPIBasedExtensionService:
)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Test empty api_endpoint
extension_data.name = fake.company()
extension_data.api_endpoint = ""
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Test empty api_key
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = ""
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_get_all_by_tenant_id_success(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -157,11 +157,11 @@ class TestAPIBasedExtensionService:
api_key=fake.password(length=20),
)
saved_extension = APIBasedExtensionService.save(extension_data)
saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
extensions.append(saved_extension)
# Get all extensions for tenant
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id)
# Verify results
assert len(extension_list) == 3
@ -191,10 +191,12 @@ class TestAPIBasedExtensionService:
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Get extension by ID
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(
db_session_with_containers, tenant.id, created_extension.id
)
# Verify extension was retrieved correctly
assert retrieved_extension is not None
@ -219,7 +221,9 @@ class TestAPIBasedExtensionService:
# Try to get non-existent extension
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
APIBasedExtensionService.get_with_tenant_id(
db_session_with_containers, tenant.id, non_existent_extension_id
)
def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
@ -238,11 +242,11 @@ class TestAPIBasedExtensionService:
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
extension_id = created_extension.id
# Delete the extension
APIBasedExtensionService.delete(created_extension)
APIBasedExtensionService.delete(db_session_with_containers, created_extension)
# Verify extension was deleted
@ -270,7 +274,7 @@ class TestAPIBasedExtensionService:
api_key=fake.password(length=20),
)
APIBasedExtensionService.save(extension_data1)
APIBasedExtensionService.save(db_session_with_containers, extension_data1)
# Try to create second extension with same name
extension_data2 = APIBasedExtension(
tenant_id=tenant.id,
@ -280,7 +284,7 @@ class TestAPIBasedExtensionService:
)
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(extension_data2)
APIBasedExtensionService.save(db_session_with_containers, extension_data2)
def test_save_extension_update_existing(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -301,7 +305,7 @@ class TestAPIBasedExtensionService:
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Save original values for later comparison
original_name = created_extension.name
@ -320,7 +324,7 @@ class TestAPIBasedExtensionService:
created_extension.api_endpoint = new_endpoint
created_extension.api_key = new_api_key
updated_extension = APIBasedExtensionService.save(created_extension)
updated_extension = APIBasedExtensionService.save(db_session_with_containers, created_extension)
# Verify extension was updated correctly
assert updated_extension.id == created_extension.id
@ -336,7 +340,9 @@ class TestAPIBasedExtensionService:
assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2
# Verify the update by retrieving the extension again
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(
db_session_with_containers, tenant.id, created_extension.id
)
assert retrieved_extension.name == new_name
assert retrieved_extension.api_endpoint == new_endpoint
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
@ -367,7 +373,7 @@ class TestAPIBasedExtensionService:
# Try to save extension with connection error
with pytest.raises(ValueError, match="connection error: request timeout"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_save_extension_invalid_api_key_length(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -390,7 +396,7 @@ class TestAPIBasedExtensionService:
# Try to save extension with short API key
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies):
"""
@ -410,21 +416,21 @@ class TestAPIBasedExtensionService:
)
with pytest.raises(ValueError, match="name must not be empty"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Test with None api_endpoint
extension_data.name = fake.company()
extension_data.api_endpoint = None
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Test with None api_key
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
extension_data.api_key = None
with pytest.raises(ValueError, match="api_key must not be empty"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_get_all_by_tenant_id_empty_list(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -438,7 +444,7 @@ class TestAPIBasedExtensionService:
)
# Get all extensions for tenant (none exist)
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id)
# Verify empty list is returned
assert len(extension_list) == 0
@ -468,7 +474,7 @@ class TestAPIBasedExtensionService:
# Try to save extension with invalid ping response
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_save_extension_missing_ping_result(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -494,7 +500,7 @@ class TestAPIBasedExtensionService:
# Try to save extension with missing ping result
with pytest.raises(ValueError, match="{'status': 'ok'}"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_get_with_tenant_id_wrong_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -520,11 +526,11 @@ class TestAPIBasedExtensionService:
api_key=fake.password(length=20),
)
created_extension = APIBasedExtensionService.save(extension_data)
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
# Try to get extension with wrong tenant ID
with pytest.raises(ValueError, match="API based extension is not found"):
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
APIBasedExtensionService.get_with_tenant_id(db_session_with_containers, tenant2.id, created_extension.id)
def test_save_extension_api_key_exactly_four_chars_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -544,7 +550,7 @@ class TestAPIBasedExtensionService:
)
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_save_extension_api_key_exactly_five_chars_accepted(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -563,7 +569,7 @@ class TestAPIBasedExtensionService:
api_key="12345",
)
saved = APIBasedExtensionService.save(extension_data)
saved = APIBasedExtensionService.save(db_session_with_containers, extension_data)
assert saved.id is not None
def test_save_extension_requestor_constructor_error(
@ -586,7 +592,7 @@ class TestAPIBasedExtensionService:
)
with pytest.raises(ValueError, match="connection error: bad config"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_save_extension_network_exception(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -610,7 +616,7 @@ class TestAPIBasedExtensionService:
)
with pytest.raises(ValueError, match="connection error: network failure"):
APIBasedExtensionService.save(extension_data)
APIBasedExtensionService.save(db_session_with_containers, extension_data)
def test_save_extension_update_duplicate_name_rejected(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -623,26 +629,28 @@ class TestAPIBasedExtensionService:
assert tenant is not None
ext1 = APIBasedExtensionService.save(
db_session_with_containers,
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Alpha",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
),
)
ext2 = APIBasedExtensionService.save(
db_session_with_containers,
APIBasedExtension(
tenant_id=tenant.id,
name="Extension Beta",
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
),
)
# Try to rename ext2 to ext1's name
ext2.name = "Extension Alpha"
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
APIBasedExtensionService.save(ext2)
APIBasedExtensionService.save(db_session_with_containers, ext2)
def test_get_all_returns_empty_for_different_tenant(
self, db_session_with_containers: Session, mock_external_service_dependencies
@ -658,14 +666,15 @@ class TestAPIBasedExtensionService:
assert tenant1 is not None
APIBasedExtensionService.save(
db_session_with_containers,
APIBasedExtension(
tenant_id=tenant1.id,
name=fake.company(),
api_endpoint=f"https://{fake.domain_name()}/api",
api_key=fake.password(length=20),
)
),
)
assert tenant2 is not None
result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id)
result = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant2.id)
assert result == []

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import builtins
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock
from unittest.mock import ANY, MagicMock
import pytest
from flask import Flask
@ -114,7 +114,7 @@ def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypat
assert response[0]["name"] == "Weather API"
assert response[0]["api_endpoint"] == extension.api_endpoint
assert response[0]["api_key"].startswith(extension.api_key[:3])
service_mock.assert_called_once_with("tenant-123")
service_mock.assert_called_once_with(ANY, "tenant-123")
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
@ -132,7 +132,7 @@ def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pyt
response, status = APIBasedExtensionAPI().post()
args, _ = save_mock.call_args
created_extension: APIBasedExtension = args[0]
created_extension: APIBasedExtension = args[1]
assert created_extension.tenant_id == "tenant-123"
assert created_extension.name == payload["name"]
assert created_extension.api_endpoint == payload["api_endpoint"]
@ -157,7 +157,7 @@ def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatc
assert response["id"] == extension.id
assert response["name"] == extension.name
service_mock.assert_called_once_with("tenant-123", str(extension_id))
service_mock.assert_called_once_with(ANY, "tenant-123", str(extension_id))
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
@ -187,7 +187,7 @@ def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkey
assert existing_extension.name == payload["name"]
assert existing_extension.api_endpoint == payload["api_endpoint"]
assert existing_extension.api_key == "keep-me"
save_mock.assert_called_once_with(existing_extension)
save_mock.assert_called_once_with(ANY, existing_extension)
assert response["name"] == payload["name"]
assert response["api_key"] == _masked_api_key("keep-me")
@ -217,7 +217,7 @@ def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flas
response = APIBasedExtensionDetailAPI().post(extension_id)
assert existing_extension.api_key == "new-secret"
save_mock.assert_called_once_with(existing_extension)
save_mock.assert_called_once_with(ANY, existing_extension)
assert response["name"] == payload["name"]
assert response["api_key"] == _masked_api_key(payload["api_key"])
@ -239,6 +239,6 @@ def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeyp
):
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
delete_mock.assert_called_once_with(existing_extension)
delete_mock.assert_called_once_with(ANY, existing_extension)
assert status == 204
assert response == ""