mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 13:01:16 +08:00
refactor: accept db.session explicitly in ApiKeyAuthService (#37832)
Co-authored-by: kunalj1-arch <kunal.j1@turing.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
50b3228bc7
commit
d0b2239c60
@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.common.fields import SimpleResultResponse
|
||||
from controllers.common.schema import register_response_schema_models, register_schema_models
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -58,7 +59,7 @@ class ApiKeyAuthDataSource(Resource):
|
||||
@account_initialization_required
|
||||
@with_current_tenant_id
|
||||
def get(self, current_tenant_id: str):
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_tenant_id)
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(db.session(), current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
"sources": [
|
||||
@ -92,7 +93,7 @@ class ApiKeyAuthDataSourceBinding(Resource):
|
||||
data = payload.model_dump()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(data)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_tenant_id, data)
|
||||
ApiKeyAuthService.create_provider_auth(db.session(), current_tenant_id, data)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {"result": "success"}, 200
|
||||
@ -109,6 +110,6 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@with_current_tenant_id
|
||||
def delete(self, current_tenant_id: str, binding_id: UUID):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
ApiKeyAuthService.delete_provider_auth(current_tenant_id, str(binding_id))
|
||||
ApiKeyAuthService.delete_provider_auth(db.session(), current_tenant_id, str(binding_id))
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -2,17 +2,17 @@ import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str):
|
||||
data_source_api_key_bindings = db.session.scalars(
|
||||
def get_provider_auth_list(session: Session, tenant_id: str):
|
||||
data_source_api_key_bindings = session.scalars(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
)
|
||||
@ -20,7 +20,7 @@ class ApiKeyAuthService:
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict[str, Any]):
|
||||
def create_provider_auth(session: Session, tenant_id: str, args: dict[str, Any]):
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
@ -31,12 +31,12 @@ class ApiKeyAuthService:
|
||||
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
|
||||
)
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
session.add(data_source_api_key_binding)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.scalar(
|
||||
def get_auth_credentials(session: Session, tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
@ -52,16 +52,16 @@ class ApiKeyAuthService:
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.scalar(
|
||||
def delete_provider_auth(session: Session, tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id,
|
||||
)
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
session.delete(data_source_api_key_binding)
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Controller integration tests for API key data source auth routes."""
|
||||
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
from sqlalchemy import select
|
||||
@ -85,7 +85,7 @@ def test_create_binding_successful(
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.get_json() == {"result": "success"}
|
||||
create_auth.assert_called_once_with(tenant_id, payload)
|
||||
create_auth.assert_called_once_with(ANY, tenant_id, payload)
|
||||
|
||||
|
||||
def test_create_binding_failure(
|
||||
|
||||
@ -51,7 +51,7 @@ class TestApiKeyAuthService:
|
||||
self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
|
||||
|
||||
assert len(result) >= 1
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
@ -61,7 +61,7 @@ class TestApiKeyAuthService:
|
||||
def test_get_provider_auth_list_empty(
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id
|
||||
):
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
|
||||
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
assert tenant_results == []
|
||||
@ -74,7 +74,7 @@ class TestApiKeyAuthService:
|
||||
)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_provider_auth_list(tenant_id)
|
||||
result = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id)
|
||||
|
||||
tenant_results = [r for r in result if r.tenant_id == tenant_id]
|
||||
assert tenant_results == []
|
||||
@ -95,7 +95,7 @@ class TestApiKeyAuthService:
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123"
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
|
||||
|
||||
mock_factory.assert_called_once()
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
@ -118,7 +118,7 @@ class TestApiKeyAuthService:
|
||||
mock_auth_instance.validate_credentials.return_value = False
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all()
|
||||
@ -142,7 +142,7 @@ class TestApiKeyAuthService:
|
||||
|
||||
original_key = mock_args["credentials"]["config"]["api_key"]
|
||||
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id, mock_args)
|
||||
|
||||
assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123"
|
||||
assert mock_args["credentials"]["config"]["api_key"] != original_key
|
||||
@ -166,14 +166,14 @@ class TestApiKeyAuthService:
|
||||
)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
|
||||
|
||||
assert result == mock_credentials
|
||||
|
||||
def test_get_auth_credentials_not_found(
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id, category, provider
|
||||
):
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -190,7 +190,7 @@ class TestApiKeyAuthService:
|
||||
)
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider)
|
||||
result = ApiKeyAuthService.get_auth_credentials(db_session_with_containers, tenant_id, category, provider)
|
||||
|
||||
assert result == special_credentials
|
||||
assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%"
|
||||
@ -204,7 +204,7 @@ class TestApiKeyAuthService:
|
||||
binding_id = binding.id
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id)
|
||||
ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, binding_id)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first()
|
||||
@ -214,7 +214,7 @@ class TestApiKeyAuthService:
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id
|
||||
):
|
||||
# Should not raise when binding not found
|
||||
ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4()))
|
||||
ApiKeyAuthService.delete_provider_auth(db_session_with_containers, tenant_id, str(uuid4()))
|
||||
|
||||
def test_validate_api_key_auth_args_success(self, mock_args):
|
||||
ApiKeyAuthService.validate_api_key_auth_args(mock_args)
|
||||
@ -288,16 +288,16 @@ class TestApiKeyAuthService:
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
mock_encrypter.encrypt_token.return_value = "encrypted_key"
|
||||
|
||||
with patch("services.auth.api_key_auth_service.db.session") as mock_session:
|
||||
mock_session.commit.side_effect = Exception("Database error")
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
mock_session = MagicMock()
|
||||
mock_session.commit.side_effect = Exception("Database error")
|
||||
with pytest.raises(Exception, match="Database error"):
|
||||
ApiKeyAuthService.create_provider_auth(mock_session, tenant_id, mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
def test_create_provider_auth_factory_exception(self, mock_factory: MagicMock, tenant_id, mock_args):
|
||||
mock_factory.side_effect = Exception("Factory error")
|
||||
with pytest.raises(Exception, match="Factory error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args)
|
||||
|
||||
@patch("services.auth.api_key_auth_service.ApiKeyAuthFactory")
|
||||
@patch("services.auth.api_key_auth_service.encrypter")
|
||||
@ -307,7 +307,7 @@ class TestApiKeyAuthService:
|
||||
mock_factory.return_value = mock_auth_instance
|
||||
mock_encrypter.encrypt_token.side_effect = Exception("Encryption error")
|
||||
with pytest.raises(Exception, match="Encryption error"):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id, mock_args)
|
||||
ApiKeyAuthService.create_provider_auth(MagicMock(), tenant_id, mock_args)
|
||||
|
||||
def test_validate_api_key_auth_args_none_input(self):
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
@ -13,6 +13,7 @@ import pytest
|
||||
from flask import Flask
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
@ -56,7 +57,7 @@ class TestAuthIntegration:
|
||||
mock_encrypt.return_value = "encrypted_fc_test_key_123"
|
||||
|
||||
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args)
|
||||
|
||||
mock_http.assert_called_once()
|
||||
call_args = mock_http.call_args
|
||||
@ -100,15 +101,15 @@ class TestAuthIntegration:
|
||||
mock_encrypt.return_value = "encrypted_key"
|
||||
|
||||
args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id_1, args1)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args1)
|
||||
|
||||
args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials}
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id_2, args2)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_2, args2)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1)
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2)
|
||||
result1 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_1)
|
||||
result2 = ApiKeyAuthService.get_provider_auth_list(db_session_with_containers, tenant_id_2)
|
||||
|
||||
assert len(result1) == 1
|
||||
assert result1[0].tenant_id == tenant_id_1
|
||||
@ -118,7 +119,9 @@ class TestAuthIntegration:
|
||||
def test_cross_tenant_access_prevention(
|
||||
self, flask_app_with_containers: Flask, db_session_with_containers: Session, tenant_id_2, category
|
||||
):
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL)
|
||||
result = ApiKeyAuthService.get_auth_credentials(
|
||||
db_session_with_containers, tenant_id_2, category, AuthType.FIRECRAWL
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@ -160,7 +163,7 @@ class TestAuthIntegration:
|
||||
"provider": AuthType.FIRECRAWL,
|
||||
"credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}},
|
||||
}
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args)
|
||||
ApiKeyAuthService.create_provider_auth(db.session(), tenant_id_1, thread_args)
|
||||
results.append("success")
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
@ -213,7 +216,7 @@ class TestAuthIntegration:
|
||||
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
|
||||
|
||||
with pytest.raises(httpx.RequestError):
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all()
|
||||
@ -250,11 +253,13 @@ class TestAuthIntegration:
|
||||
mock_encrypt.return_value = "encrypted_key"
|
||||
|
||||
args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials}
|
||||
ApiKeyAuthService.create_provider_auth(tenant_id_1, args)
|
||||
ApiKeyAuthService.create_provider_auth(db_session_with_containers, tenant_id_1, args)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
|
||||
result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL)
|
||||
result = ApiKeyAuthService.get_auth_credentials(
|
||||
db_session_with_containers, tenant_id_1, category, AuthType.FIRECRAWL
|
||||
)
|
||||
assert result is not None
|
||||
assert result["config"]["api_key"] == "encrypted_key"
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from datetime import UTC, datetime
|
||||
from inspect import unwrap
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import PropertyMock, patch
|
||||
from unittest.mock import ANY, PropertyMock, patch
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.data_source_bearer_auth import (
|
||||
@ -34,13 +34,16 @@ def test_list_data_source_auth_uses_injected_tenant_id() -> None:
|
||||
updated_at=datetime(2026, 1, 2, tzinfo=UTC),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
|
||||
return_value=[binding],
|
||||
) as get_provider_auth_list:
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.db"),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list",
|
||||
return_value=[binding],
|
||||
) as get_provider_auth_list,
|
||||
):
|
||||
result = method(api, "tenant-1")
|
||||
|
||||
get_provider_auth_list.assert_called_once_with("tenant-1")
|
||||
get_provider_auth_list.assert_called_once_with(ANY, "tenant-1")
|
||||
assert result["sources"][0]["id"] == "binding-1"
|
||||
assert result["sources"][0]["provider"] == "custom"
|
||||
|
||||
@ -56,12 +59,13 @@ def test_create_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
|
||||
with (
|
||||
_payload_patch(payload),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.db"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args"),
|
||||
patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") as create_auth,
|
||||
):
|
||||
result, status = method(api, "tenant-1")
|
||||
|
||||
create_auth.assert_called_once_with("tenant-1", payload)
|
||||
create_auth.assert_called_once_with(ANY, "tenant-1", payload)
|
||||
assert result == {"result": "success"}
|
||||
assert status == 200
|
||||
|
||||
@ -70,11 +74,14 @@ def test_delete_data_source_auth_binding_uses_injected_tenant_id() -> None:
|
||||
api = ApiKeyAuthDataSourceBindingDelete()
|
||||
method = unwrap(api.delete)
|
||||
|
||||
with patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
|
||||
) as delete_provider_auth:
|
||||
with (
|
||||
patch("controllers.console.auth.data_source_bearer_auth.db"),
|
||||
patch(
|
||||
"controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth"
|
||||
) as delete_provider_auth,
|
||||
):
|
||||
result, status = method(api, "tenant-1", "binding-1")
|
||||
|
||||
delete_provider_auth.assert_called_once_with("tenant-1", "binding-1")
|
||||
delete_provider_auth.assert_called_once_with(ANY, "tenant-1", "binding-1")
|
||||
assert result == ""
|
||||
assert status == 204
|
||||
|
||||
Loading…
Reference in New Issue
Block a user