refactor: accept db.session explicitly in ApiKeyAuthService (#37832)

Co-authored-by: kunalj1-arch <kunal.j1@turing.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
kunal 2026-06-24 06:50:47 +05:30 committed by GitHub
parent 50b3228bc7
commit d0b2239c60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 68 additions and 55 deletions

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
from controllers.common.fields import SimpleResultResponse
from controllers.common.schema import register_response_schema_models, register_schema_models
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.login import login_required
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -58,7 +59,7 @@ class ApiKeyAuthDataSource(Resource):
@account_initialization_required
@with_current_tenant_id
def get(self, current_tenant_id: str):
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(db.session(), current_tenant_id)
if data_source_api_key_bindings:
return {
"sources": [
@ -92,7 +93,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
data = payload.model_dump()
ApiKeyAuthService.validate_api_key_auth_args(data)
try:
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
ApiKeyAuthService.create_provider_auth(db.session(), current_tenant_id, data)
except Exception as e:
raise ApiKeyAuthFailedError(str(e))
return {"result": "success"}, 200
@ -109,6 +110,6 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
@with_current_tenant_id
def delete(self, current_tenant_id: str, binding_id: UUID):
# The role of the current user in the table must be admin or owner
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
ApiKeyAuthService.delete_provider_auth(db.session(), current_tenant_id, str(binding_id))
return "", 204

View File

@ -2,17 +2,17 @@ import json
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = db.session.scalars(
def get_provider_auth_list(session: Session, tenant_id: str):
data_source_api_key_bindings = session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
@ -20,7 +20,7 @@ class ApiKeyAuthService:
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict[str, Any]):
def create_provider_auth(session: Session, tenant_id: str, args: dict[str, Any]):
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
if auth_result:
# Encrypt the api key
@ -31,12 +31,12 @@ class ApiKeyAuthService:
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
)
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
session.add(data_source_api_key_binding)
session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = db.session.scalar(
def get_auth_credentials(session: Session, tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = session.scalar(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
@ -52,16 +52,16 @@ class ApiKeyAuthService:
return credentials
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = db.session.scalar(
def delete_provider_auth(session: Session, tenant_id: str, binding_id: str):
data_source_api_key_binding = session.scalar(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.id == binding_id,
)
)
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)
db.session.commit()
session.delete(data_source_api_key_binding)
session.commit()
@classmethod
def validate_api_key_auth_args(cls, args):

View File

@ -1,7 +1,7 @@
"""Controller integration tests for API key data source auth routes."""
import json
from unittest.mock import patch
from unittest.mock import ANY, patch
from flask.testing import FlaskClient
from sqlalchemy import select
@ -85,7 +85,7 @@ def test_create_binding_successful(
assert response.status_code == 200
assert response.get_json() == {"result": "success"}
create_auth.assert_called_once_with(tenant_id, payload)
create_auth.assert_called_once_with(ANY, tenant_id, payload)
def test_create_binding_failure(

View File

@ -51,7 +51,7 @@ class TestApiKeyAuthService:
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
assert len(result) >= 1
tenant_results = [r for r in result if r.tenant_id == tenant_id]
@ -61,7 +61,7 @@ class TestApiKeyAuthService:
def test_get_provider_auth_list_empty(
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id
):
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert tenant_results == []
@ -74,7 +74,7 @@ class TestApiKeyAuthService:
)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
tenant_results = [r for r in result if r.tenant_id == tenant_id]
assert tenant_results == []
@ -95,7 +95,7 @@ class TestApiKeyAuthService:
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123"
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
mock_factory.assert_called_once()
mock_auth_instance.validate_credentials.assert_called_once()
@ -118,7 +118,7 @@ class TestApiKeyAuthService:
mock_auth_instance.validate_credentials.return_value = False
mock_factory.return_value = mock_auth_instance
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
db_session_with_containers.expire_all()
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all()
@ -142,7 +142,7 @@ class TestApiKeyAuthService:
original_key = mock_args["credentials"]["config"]["api_key"]
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123"
assert mock_args["credentials"]["config"]["api_key"] != original_key
@ -166,14 +166,14 @@ class TestApiKeyAuthService:
)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
assert result == mock_credentials
def test_get_auth_credentials_not_found(
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, category, provider
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
assert result is None
@ -190,7 +190,7 @@ class TestApiKeyAuthService:
)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
assert result == special_credentials
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
@ -204,7 +204,7 @@ class TestApiKeyAuthService:
binding_id = binding.id
db_session_with_containers.expire_all()
ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id)
ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, binding_id)
db_session_with_containers.expire_all()
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
@ -214,7 +214,7 @@ class TestApiKeyAuthService:
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id
):
# Should not raise when binding not found
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))
ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, str(uuid4()))
def test_validate_api_key_auth_args_success(self, mock_args):
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
@ -288,16 +288,16 @@ class TestApiKeyAuthService:
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.return_value = "encrypted_key"
with patch("services.auth.api_key_auth_service.db.session") as mock_session:
mock_session.commit.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
mock_session = MagicMock()
mock_session.commit.side_effect = Exception("Database error")
with pytest.raises(Exception, match="Database error"):
ApiKeyAuthService.create_provider_auth(mock_session, tenant_id, mock_args)
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args):
mock_factory.side_effect = Exception("Factory error")
with pytest.raises(Exception, match="Factory error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args)
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
@patch("services.auth.api_key_auth_service.encrypter")
@ -307,7 +307,7 @@ class TestApiKeyAuthService:
mock_factory.return_value = mock_auth_instance
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
with pytest.raises(Exception, match="Encryption error"):
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args)
def test_validate_api_key_auth_args_none_input(self):
with pytest.raises(TypeError):

View File

@ -13,6 +13,7 @@ import pytest
from flask import Flask
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService
@ -56,7 +57,7 @@ class TestAuthIntegration:
mock_encrypt.return_value = "encrypted_fc_test_key_123"
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args)
mock_http.assert_called_once()
call_args = mock_http.call_args
@ -100,15 +101,15 @@ class TestAuthIntegration:
mock_encrypt.return_value = "encrypted_key"
args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_1, args1)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args1)
args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_2, args2)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_2, args2)
db_session_with_containers.expire_all()
result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1)
result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2)
result1 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_1)
result2 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_2)
assert len(result1) == 1
assert result1[0].tenant_id == tenant_id_1
@ -118,7 +119,9 @@ class TestAuthIntegration:
def test_cross_tenant_access_prevention(
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id_2, category
):
result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL)
result = ApiKeyAuthService.get_auth_credentials(
db_session_with_containers, tenant_id_2, category, AuthType.FIRECRAWL
)
assert result is None
@ -160,7 +163,7 @@ class TestAuthIntegration:
"provider": AuthType.FIRECRAWL,
"credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}},
}
ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args)
ApiKeyAuthService.create_provider_auth(db.session(), tenant_id_1, thread_args)
results.append("success")
except Exception as e:
exceptions.append(e)
@ -213,7 +216,7 @@ class TestAuthIntegration:
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args)
db_session_with_containers.expire_all()
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all()
@ -250,11 +253,13 @@ class TestAuthIntegration:
mock_encrypt.return_value = "encrypted_key"
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args)
db_session_with_containers.expire_all()
result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL)
result = ApiKeyAuthService.get_auth_credentials(
db_session_with_containers, tenant_id_1, category, AuthType.FIRECRAWL
)
assert result is not None
assert result["config"]["api_key"] == "encrypted_key"

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from datetime import UTC, datetime
from inspect import unwrap
from types import SimpleNamespace
from unittest.mock import PropertyMock, patch
from unittest.mock import ANY, PropertyMock, patch
from controllers.console import console_ns
from controllers.console.auth.data_source_bearer_auth import (
@ -34,13 +34,16 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None:
updated_at=datetime(2026, 1, 2, tzinfo=UTC),
)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
return_value=[binding],
) as get_provider_auth_list:
with (
patch("controllers.console.auth.data_source_bearer_auth.db"),
patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
return_value=[binding],
) as get_provider_auth_list,
):
result = method(api, "tenant-1")
get_provider_auth_list.assert_called_once_with("tenant-1")
get_provider_auth_list.assert_called_once_with(ANY, "tenant-1")
assert result["sources"][0]["id"] == "binding-1"
assert result["sources"][0]["provider"] == "custom"
@ -56,12 +59,13 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
with (
_payload_patch(payload),
patch("controllers.console.auth.data_source_bearer_auth.db"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
):
result, status = method(api, "tenant-1")
create_auth.assert_called_once_with("tenant-1", payload)
create_auth.assert_called_once_with(ANY, "tenant-1", payload)
assert result == {"result": "success"}
assert status == 200
@ -70,11 +74,14 @@ def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
api = ApiKeyAuthDataSourceBindingDelete()
method = unwrap(api.delete)
with patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
) as delete_provider_auth:
with (
patch("controllers.console.auth.data_source_bearer_auth.db"),
patch(
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
) as delete_provider_auth,
):
result, status = method(api, "tenant-1", "binding-1")
delete_provider_auth.assert_called_once_with("tenant-1", "binding-1")
delete_provider_auth.assert_called_once_with(ANY, "tenant-1", "binding-1")
assert result == ""
assert status == 204