diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index 1de206c73db..11fab84a831 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field from controllers.common.fields import SimpleResultResponse from controllers.common.schema import register_response_schema_models, register_schema_models +from extensions.ext_database import db from fields.base import ResponseModel from libs.login import login_required from services.auth.api_key_auth_service import ApiKeyAuthService @@ -58,7 +59,7 @@ class ApiKeyAuthDataSource(Resource): @account_initialization_required @with_current_tenant_id def get(self, current_tenant_id: str): - data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id) + data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(db.session(), current_tenant_id) if data_source_api_key_bindings: return { "sources": [ @@ -92,7 +93,7 @@ class ApiKeyAuthDataSourceBinding(Resource): data = payload.model_dump() ApiKeyAuthService.validate_api_key_auth_args(data) try: - ApiKeyAuthService.create_provider_auth(current_tenant_id, data) + ApiKeyAuthService.create_provider_auth(db.session(), current_tenant_id, data) except Exception as e: raise ApiKeyAuthFailedError(str(e)) return {"result": "success"}, 200 @@ -109,6 +110,6 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): @with_current_tenant_id def delete(self, current_tenant_id: str, binding_id: UUID): # The role of the current user in the table must be admin or owner - ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id)) + ApiKeyAuthService.delete_provider_auth(db.session(), current_tenant_id, str(binding_id)) return "", 204 diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 36b15170567..42f1d4d8d40 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -2,17 +2,17 @@ import json from typing import Any from sqlalchemy import select +from sqlalchemy.orm import Session from core.helper import encrypter -from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: @staticmethod - def get_provider_auth_list(tenant_id: str): - data_source_api_key_bindings = db.session.scalars( + def get_provider_auth_list(session: Session, tenant_id: str): + data_source_api_key_bindings = session.scalars( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False) ) @@ -20,7 +20,7 @@ class ApiKeyAuthService: return data_source_api_key_bindings @staticmethod - def create_provider_auth(tenant_id: str, args: dict[str, Any]): + def create_provider_auth(session: Session, tenant_id: str, args: dict[str, Any]): auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key @@ -31,12 +31,12 @@ class ApiKeyAuthService: tenant_id=tenant_id, category=args["category"], provider=args["provider"] ) data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) - db.session.add(data_source_api_key_binding) - db.session.commit() + session.add(data_source_api_key_binding) + session.commit() @staticmethod - def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.scalar( + def get_auth_credentials(session: Session, tenant_id: str, category: str, provider: str): + data_source_api_key_bindings = session.scalar( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, @@ -52,16 +52,16 @@ class ApiKeyAuthService: return credentials @staticmethod - def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.scalar( + def delete_provider_auth(session: Session, tenant_id: str, binding_id: str): + data_source_api_key_binding = session.scalar( select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id, ) ) if data_source_api_key_binding: - db.session.delete(data_source_api_key_binding) - db.session.commit() + session.delete(data_source_api_key_binding) + session.commit() @classmethod def validate_api_key_auth_args(cls, args): diff --git a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py index 5eb9f71e695..e55b46d38bf 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/test_containers_integration_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -1,7 +1,7 @@ """Controller integration tests for API key data source auth routes.""" import json -from unittest.mock import patch +from unittest.mock import ANY, patch from flask.testing import FlaskClient from sqlalchemy import select @@ -85,7 +85,7 @@ def test_create_binding_successful( assert response.status_code == 200 assert response.get_json() == {"result": "success"} - create_auth.assert_called_once_with(tenant_id, payload) + create_auth.assert_called_once_with(ANY, tenant_id, payload) def test_create_binding_failure( diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py index c93e61b2bfb..e2f8c8fc703 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -51,7 +51,7 @@ class TestApiKeyAuthService: self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) assert len(result) >= 1 tenant_results = [r for r in result if r.tenant_id == tenant_id] @@ -61,7 +61,7 @@ class TestApiKeyAuthService: def test_get_provider_auth_list_empty( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id ): - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] @@ -74,7 +74,7 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id) tenant_results = [r for r in result if r.tenant_id == tenant_id] assert tenant_results == [] @@ -95,7 +95,7 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) mock_factory.assert_called_once() mock_auth_instance.validate_credentials.assert_called_once() @@ -118,7 +118,7 @@ class TestApiKeyAuthService: mock_auth_instance.validate_credentials.return_value = False mock_factory.return_value = mock_auth_instance - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) db_session_with_containers.expire_all() bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() @@ -142,7 +142,7 @@ class TestApiKeyAuthService: original_key = mock_args["credentials"]["config"]["api_key"] - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args) assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" assert mock_args["credentials"]["config"]["api_key"] != original_key @@ -166,14 +166,14 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result == mock_credentials def test_get_auth_credentials_not_found( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, category, provider ): - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result is None @@ -190,7 +190,7 @@ class TestApiKeyAuthService: ) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider) assert result == special_credentials assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" @@ -204,7 +204,7 @@ class TestApiKeyAuthService: binding_id = binding.id db_session_with_containers.expire_all() - ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id) + ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, binding_id) db_session_with_containers.expire_all() remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() @@ -214,7 +214,7 @@ class TestApiKeyAuthService: self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id ): # Should not raise when binding not found - ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) + ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, str(uuid4())) def test_validate_api_key_auth_args_success(self, mock_args): ApiKeyAuthService.validate_api_key_auth_args(mock_args) @@ -288,16 +288,16 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.return_value = "encrypted_key" - with patch("services.auth.api_key_auth_service.db.session") as mock_session: - mock_session.commit.side_effect = Exception("Database error") - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + mock_session = MagicMock() + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(mock_session, tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args): mock_factory.side_effect = Exception("Factory error") with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args) @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") @patch("services.auth.api_key_auth_service.encrypter") @@ -307,7 +307,7 @@ class TestApiKeyAuthService: mock_factory.return_value = mock_auth_instance mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args) def test_validate_api_key_auth_args_none_input(self): with pytest.raises(TypeError): diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py index 1de9ce38a0b..9b86ab41f2b 100644 --- a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -13,6 +13,7 @@ import pytest from flask import Flask from sqlalchemy.orm import Session +from extensions.ext_database import db from models.source import DataSourceApiKeyAuthBinding from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_service import ApiKeyAuthService @@ -56,7 +57,7 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_fc_test_key_123" args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) mock_http.assert_called_once() call_args = mock_http.call_args @@ -100,15 +101,15 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_key" args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args1) args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_2, args2) db_session_with_containers.expire_all() - result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) - result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + result1 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_2) assert len(result1) == 1 assert result1[0].tenant_id == tenant_id_1 @@ -118,7 +119,9 @@ class TestAuthIntegration: def test_cross_tenant_access_prevention( self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id_2, category ): - result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + result = ApiKeyAuthService.get_auth_credentials( + db_session_with_containers, tenant_id_2, category, AuthType.FIRECRAWL + ) assert result is None @@ -160,7 +163,7 @@ class TestAuthIntegration: "provider": AuthType.FIRECRAWL, "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, } - ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + ApiKeyAuthService.create_provider_auth(db.session(), tenant_id_1, thread_args) results.append("success") except Exception as e: exceptions.append(e) @@ -213,7 +216,7 @@ class TestAuthIntegration: args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) db_session_with_containers.expire_all() bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() @@ -250,11 +253,13 @@ class TestAuthIntegration: mock_encrypt.return_value = "encrypted_key" args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args) db_session_with_containers.expire_all() - result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + result = ApiKeyAuthService.get_auth_credentials( + db_session_with_containers, tenant_id_1, category, AuthType.FIRECRAWL + ) assert result is not None assert result["config"]["api_key"] == "encrypted_key" diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py index 7f449bb376e..21d1932f820 100644 --- a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -3,7 +3,7 @@ from __future__ import annotations from datetime import UTC, datetime from inspect import unwrap from types import SimpleNamespace -from unittest.mock import PropertyMock, patch +from unittest.mock import ANY, PropertyMock, patch from controllers.console import console_ns from controllers.console.auth.data_source_bearer_auth import ( @@ -34,13 +34,16 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None: updated_at=datetime(2026, 1, 2, tzinfo=UTC), ) - with patch( - "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", - return_value=[binding], - ) as get_provider_auth_list: + with ( + patch("controllers.console.auth.data_source_bearer_auth.db"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list", + return_value=[binding], + ) as get_provider_auth_list, + ): result = method(api, "tenant-1") - get_provider_auth_list.assert_called_once_with("tenant-1") + get_provider_auth_list.assert_called_once_with(ANY, "tenant-1") assert result["sources"][0]["id"] == "binding-1" assert result["sources"][0]["provider"] == "custom" @@ -56,12 +59,13 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None: with ( _payload_patch(payload), + patch("controllers.console.auth.data_source_bearer_auth.db"), patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"), patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth, ): result, status = method(api, "tenant-1") - create_auth.assert_called_once_with("tenant-1", payload) + create_auth.assert_called_once_with(ANY, "tenant-1", payload) assert result == {"result": "success"} assert status == 200 @@ -70,11 +74,14 @@ def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None: api = ApiKeyAuthDataSourceBindingDelete() method = unwrap(api.delete) - with patch( - "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" - ) as delete_provider_auth: + with ( + patch("controllers.console.auth.data_source_bearer_auth.db"), + patch( + "controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth" + ) as delete_provider_auth, + ): result, status = method(api, "tenant-1", "binding-1") - delete_provider_auth.assert_called_once_with("tenant-1", "binding-1") + delete_provider_auth.assert_called_once_with(ANY, "tenant-1", "binding-1") assert result == "" assert status == 204