mirror of
https://github.com/langgenius/dify.git
synced 2026-06-22 19:21:13 +08:00
refactor: accept db.session explicitly in APIBasedExtensionService (#37693)
This commit is contained in:
parent
75d50455d6
commit
9b4dd9d4e8
@ -7,6 +7,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, TypeAdapter, field_validator
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import to_timestamp
|
||||
from libs.login import login_required
|
||||
@ -126,7 +127,7 @@ class APIBasedExtensionAPI(Resource):
|
||||
def get(self, current_tenant_id: str):
|
||||
return [
|
||||
_serialize_api_based_extension(extension)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(current_tenant_id)
|
||||
for extension in APIBasedExtensionService.get_all_by_tenant_id(db.session(), current_tenant_id)
|
||||
]
|
||||
|
||||
@console_ns.doc("create_api_based_extension")
|
||||
@ -147,7 +148,12 @@ class APIBasedExtensionAPI(Resource):
|
||||
api_key=payload.api_key,
|
||||
)
|
||||
|
||||
return _serialize_saved_api_based_extension(APIBasedExtensionService.save(extension_data), payload.api_key), 201
|
||||
return (
|
||||
_serialize_saved_api_based_extension(
|
||||
APIBasedExtensionService.save(db.session(), extension_data), payload.api_key
|
||||
),
|
||||
201,
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/api-based-extension/<uuid:id>")
|
||||
@ -164,7 +170,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
api_based_extension_id = str(id)
|
||||
|
||||
return _serialize_api_based_extension(
|
||||
APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
APIBasedExtensionService.get_with_tenant_id(db.session(), current_tenant_id, api_based_extension_id)
|
||||
)
|
||||
|
||||
@console_ns.doc("update_api_based_extension")
|
||||
@ -179,7 +185,9 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
def post(self, current_tenant_id: str, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(
|
||||
db.session(), current_tenant_id, api_based_extension_id
|
||||
)
|
||||
|
||||
payload = APIBasedExtensionPayload.model_validate(console_ns.payload or {})
|
||||
api_key_for_response = extension_data_from_db.api_key
|
||||
@ -192,7 +200,7 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
api_key_for_response = payload.api_key
|
||||
|
||||
return _serialize_saved_api_based_extension(
|
||||
APIBasedExtensionService.save(extension_data_from_db),
|
||||
APIBasedExtensionService.save(db.session(), extension_data_from_db),
|
||||
api_key_for_response,
|
||||
)
|
||||
|
||||
@ -207,8 +215,10 @@ class APIBasedExtensionDetailAPI(Resource):
|
||||
def delete(self, current_tenant_id: str, id: UUID):
|
||||
api_based_extension_id = str(id)
|
||||
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(current_tenant_id, api_based_extension_id)
|
||||
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(
|
||||
db.session(), current_tenant_id, api_based_extension_id
|
||||
)
|
||||
|
||||
APIBasedExtensionService.delete(extension_data_from_db)
|
||||
APIBasedExtensionService.delete(db.session(), extension_data_from_db)
|
||||
|
||||
return "", 204
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||
from core.helper.encrypter import decrypt_token, encrypt_token
|
||||
from extensions.ext_database import db
|
||||
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
|
||||
|
||||
class APIBasedExtensionService:
|
||||
@staticmethod
|
||||
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
|
||||
def get_all_by_tenant_id(session: Session, tenant_id: str) -> list[APIBasedExtension]:
|
||||
extension_list = list(
|
||||
db.session.scalars(
|
||||
session.scalars(
|
||||
select(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id)
|
||||
.order_by(APIBasedExtension.created_at.desc())
|
||||
@ -23,23 +23,23 @@ class APIBasedExtensionService:
|
||||
return extension_list
|
||||
|
||||
@classmethod
|
||||
def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
|
||||
cls._validation(extension_data)
|
||||
def save(cls, session: Session, extension_data: APIBasedExtension) -> APIBasedExtension:
|
||||
cls._validation(session, extension_data)
|
||||
|
||||
extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
|
||||
|
||||
db.session.add(extension_data)
|
||||
db.session.commit()
|
||||
session.add(extension_data)
|
||||
session.commit()
|
||||
return extension_data
|
||||
|
||||
@staticmethod
|
||||
def delete(extension_data: APIBasedExtension):
|
||||
db.session.delete(extension_data)
|
||||
db.session.commit()
|
||||
def delete(session: Session, extension_data: APIBasedExtension):
|
||||
session.delete(extension_data)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = db.session.scalar(
|
||||
def get_with_tenant_id(session: Session, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
|
||||
.limit(1)
|
||||
@ -53,14 +53,14 @@ class APIBasedExtensionService:
|
||||
return extension
|
||||
|
||||
@classmethod
|
||||
def _validation(cls, extension_data: APIBasedExtension):
|
||||
def _validation(cls, session: Session, extension_data: APIBasedExtension):
|
||||
# name
|
||||
if not extension_data.name:
|
||||
raise ValueError("name must not be empty")
|
||||
|
||||
if not extension_data.id:
|
||||
# case one: check new data, name must be unique
|
||||
is_name_existed = db.session.scalar(
|
||||
is_name_existed = session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(
|
||||
APIBasedExtension.tenant_id == extension_data.tenant_id,
|
||||
@ -73,7 +73,7 @@ class APIBasedExtensionService:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
else:
|
||||
# case two: check existing data, name must be unique
|
||||
is_name_existed = db.session.scalar(
|
||||
is_name_existed = session.scalar(
|
||||
select(APIBasedExtension)
|
||||
.where(
|
||||
APIBasedExtension.tenant_id == extension_data.tenant_id,
|
||||
|
||||
@ -97,12 +97,13 @@ def test_list_scopes_api_based_extensions_to_authenticated_tenant(
|
||||
assert account_create_response.status_code == 201
|
||||
|
||||
APIBasedExtensionService.save(
|
||||
db_session_with_containers,
|
||||
APIBasedExtension(
|
||||
tenant_id=foreign_tenant_id,
|
||||
name="Foreign API",
|
||||
api_endpoint="https://foreign.example.com/hook",
|
||||
api_key="foreign-secret-12345",
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
response = test_client_with_containers.get(
|
||||
|
||||
@ -81,7 +81,7 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
# Save extension
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Verify extension was saved correctly
|
||||
assert saved_extension.id is not None
|
||||
@ -119,21 +119,21 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Test empty api_endpoint
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = ""
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Test empty api_key
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = ""
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_success(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -157,11 +157,11 @@ class TestAPIBasedExtensionService:
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
saved_extension = APIBasedExtensionService.save(extension_data)
|
||||
saved_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
extensions.append(saved_extension)
|
||||
|
||||
# Get all extensions for tenant
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id)
|
||||
|
||||
# Verify results
|
||||
assert len(extension_list) == 3
|
||||
@ -191,10 +191,12 @@ class TestAPIBasedExtensionService:
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Get extension by ID
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(
|
||||
db_session_with_containers, tenant.id, created_extension.id
|
||||
)
|
||||
|
||||
# Verify extension was retrieved correctly
|
||||
assert retrieved_extension is not None
|
||||
@ -219,7 +221,9 @@ class TestAPIBasedExtensionService:
|
||||
|
||||
# Try to get non-existent extension
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id)
|
||||
APIBasedExtensionService.get_with_tenant_id(
|
||||
db_session_with_containers, tenant.id, non_existent_extension_id
|
||||
)
|
||||
|
||||
def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
@ -238,11 +242,11 @@ class TestAPIBasedExtensionService:
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
extension_id = created_extension.id
|
||||
|
||||
# Delete the extension
|
||||
APIBasedExtensionService.delete(created_extension)
|
||||
APIBasedExtensionService.delete(db_session_with_containers, created_extension)
|
||||
|
||||
# Verify extension was deleted
|
||||
|
||||
@ -270,7 +274,7 @@ class TestAPIBasedExtensionService:
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
APIBasedExtensionService.save(extension_data1)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data1)
|
||||
# Try to create second extension with same name
|
||||
extension_data2 = APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
@ -280,7 +284,7 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(extension_data2)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data2)
|
||||
|
||||
def test_save_extension_update_existing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -301,7 +305,7 @@ class TestAPIBasedExtensionService:
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Save original values for later comparison
|
||||
original_name = created_extension.name
|
||||
@ -320,7 +324,7 @@ class TestAPIBasedExtensionService:
|
||||
created_extension.api_endpoint = new_endpoint
|
||||
created_extension.api_key = new_api_key
|
||||
|
||||
updated_extension = APIBasedExtensionService.save(created_extension)
|
||||
updated_extension = APIBasedExtensionService.save(db_session_with_containers, created_extension)
|
||||
|
||||
# Verify extension was updated correctly
|
||||
assert updated_extension.id == created_extension.id
|
||||
@ -336,7 +340,9 @@ class TestAPIBasedExtensionService:
|
||||
assert mock_external_service_dependencies["requestor_instance"].request.call_count == 2
|
||||
|
||||
# Verify the update by retrieving the extension again
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(tenant.id, created_extension.id)
|
||||
retrieved_extension = APIBasedExtensionService.get_with_tenant_id(
|
||||
db_session_with_containers, tenant.id, created_extension.id
|
||||
)
|
||||
assert retrieved_extension.name == new_name
|
||||
assert retrieved_extension.api_endpoint == new_endpoint
|
||||
assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved
|
||||
@ -367,7 +373,7 @@ class TestAPIBasedExtensionService:
|
||||
|
||||
# Try to save extension with connection error
|
||||
with pytest.raises(ValueError, match="connection error: request timeout"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_save_extension_invalid_api_key_length(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -390,7 +396,7 @@ class TestAPIBasedExtensionService:
|
||||
|
||||
# Try to save extension with short API key
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies):
|
||||
"""
|
||||
@ -410,21 +416,21 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="name must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Test with None api_endpoint
|
||||
extension_data.name = fake.company()
|
||||
extension_data.api_endpoint = None
|
||||
|
||||
with pytest.raises(ValueError, match="api_endpoint must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Test with None api_key
|
||||
extension_data.api_endpoint = f"https://{fake.domain_name()}/api"
|
||||
extension_data.api_key = None
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must not be empty"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_get_all_by_tenant_id_empty_list(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -438,7 +444,7 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
# Get all extensions for tenant (none exist)
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(tenant.id)
|
||||
extension_list = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant.id)
|
||||
|
||||
# Verify empty list is returned
|
||||
assert len(extension_list) == 0
|
||||
@ -468,7 +474,7 @@ class TestAPIBasedExtensionService:
|
||||
|
||||
# Try to save extension with invalid ping response
|
||||
with pytest.raises(ValueError, match="{'result': 'invalid'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_save_extension_missing_ping_result(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -494,7 +500,7 @@ class TestAPIBasedExtensionService:
|
||||
|
||||
# Try to save extension with missing ping result
|
||||
with pytest.raises(ValueError, match="{'status': 'ok'}"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_get_with_tenant_id_wrong_tenant(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -520,11 +526,11 @@ class TestAPIBasedExtensionService:
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
|
||||
created_extension = APIBasedExtensionService.save(extension_data)
|
||||
created_extension = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
# Try to get extension with wrong tenant ID
|
||||
with pytest.raises(ValueError, match="API based extension is not found"):
|
||||
APIBasedExtensionService.get_with_tenant_id(tenant2.id, created_extension.id)
|
||||
APIBasedExtensionService.get_with_tenant_id(db_session_with_containers, tenant2.id, created_extension.id)
|
||||
|
||||
def test_save_extension_api_key_exactly_four_chars_rejected(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -544,7 +550,7 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="api_key must be at least 5 characters"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_save_extension_api_key_exactly_five_chars_accepted(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -563,7 +569,7 @@ class TestAPIBasedExtensionService:
|
||||
api_key="12345",
|
||||
)
|
||||
|
||||
saved = APIBasedExtensionService.save(extension_data)
|
||||
saved = APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
assert saved.id is not None
|
||||
|
||||
def test_save_extension_requestor_constructor_error(
|
||||
@ -586,7 +592,7 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="connection error: bad config"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_save_extension_network_exception(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -610,7 +616,7 @@ class TestAPIBasedExtensionService:
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="connection error: network failure"):
|
||||
APIBasedExtensionService.save(extension_data)
|
||||
APIBasedExtensionService.save(db_session_with_containers, extension_data)
|
||||
|
||||
def test_save_extension_update_duplicate_name_rejected(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -623,26 +629,28 @@ class TestAPIBasedExtensionService:
|
||||
assert tenant is not None
|
||||
|
||||
ext1 = APIBasedExtensionService.save(
|
||||
db_session_with_containers,
|
||||
APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Extension Alpha",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
),
|
||||
)
|
||||
ext2 = APIBasedExtensionService.save(
|
||||
db_session_with_containers,
|
||||
APIBasedExtension(
|
||||
tenant_id=tenant.id,
|
||||
name="Extension Beta",
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
# Try to rename ext2 to ext1's name
|
||||
ext2.name = "Extension Alpha"
|
||||
with pytest.raises(ValueError, match="name must be unique, it is already existed"):
|
||||
APIBasedExtensionService.save(ext2)
|
||||
APIBasedExtensionService.save(db_session_with_containers, ext2)
|
||||
|
||||
def test_get_all_returns_empty_for_different_tenant(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
@ -658,14 +666,15 @@ class TestAPIBasedExtensionService:
|
||||
assert tenant1 is not None
|
||||
|
||||
APIBasedExtensionService.save(
|
||||
db_session_with_containers,
|
||||
APIBasedExtension(
|
||||
tenant_id=tenant1.id,
|
||||
name=fake.company(),
|
||||
api_endpoint=f"https://{fake.domain_name()}/api",
|
||||
api_key=fake.password(length=20),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
assert tenant2 is not None
|
||||
result = APIBasedExtensionService.get_all_by_tenant_id(tenant2.id)
|
||||
result = APIBasedExtensionService.get_all_by_tenant_id(db_session_with_containers, tenant2.id)
|
||||
assert result == []
|
||||
|
||||
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
@ -114,7 +114,7 @@ def test_api_based_extension_get_returns_tenant_extensions(app: Flask, monkeypat
|
||||
assert response[0]["name"] == "Weather API"
|
||||
assert response[0]["api_endpoint"] == extension.api_endpoint
|
||||
assert response[0]["api_key"].startswith(extension.api_key[:3])
|
||||
service_mock.assert_called_once_with("tenant-123")
|
||||
service_mock.assert_called_once_with(ANY, "tenant-123")
|
||||
|
||||
|
||||
def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -132,7 +132,7 @@ def test_api_based_extension_post_creates_extension(app: Flask, monkeypatch: pyt
|
||||
response, status = APIBasedExtensionAPI().post()
|
||||
|
||||
args, _ = save_mock.call_args
|
||||
created_extension: APIBasedExtension = args[0]
|
||||
created_extension: APIBasedExtension = args[1]
|
||||
assert created_extension.tenant_id == "tenant-123"
|
||||
assert created_extension.name == payload["name"]
|
||||
assert created_extension.api_endpoint == payload["api_endpoint"]
|
||||
@ -157,7 +157,7 @@ def test_api_based_extension_detail_get_fetches_extension(app: Flask, monkeypatc
|
||||
|
||||
assert response["id"] == extension.id
|
||||
assert response["name"] == extension.name
|
||||
service_mock.assert_called_once_with("tenant-123", str(extension_id))
|
||||
service_mock.assert_called_once_with(ANY, "tenant-123", str(extension_id))
|
||||
|
||||
|
||||
def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkeypatch: pytest.MonkeyPatch):
|
||||
@ -187,7 +187,7 @@ def test_api_based_extension_detail_post_keeps_hidden_api_key(app: Flask, monkey
|
||||
assert existing_extension.name == payload["name"]
|
||||
assert existing_extension.api_endpoint == payload["api_endpoint"]
|
||||
assert existing_extension.api_key == "keep-me"
|
||||
save_mock.assert_called_once_with(existing_extension)
|
||||
save_mock.assert_called_once_with(ANY, existing_extension)
|
||||
assert response["name"] == payload["name"]
|
||||
assert response["api_key"] == _masked_api_key("keep-me")
|
||||
|
||||
@ -217,7 +217,7 @@ def test_api_based_extension_detail_post_updates_api_key_when_provided(app: Flas
|
||||
response = APIBasedExtensionDetailAPI().post(extension_id)
|
||||
|
||||
assert existing_extension.api_key == "new-secret"
|
||||
save_mock.assert_called_once_with(existing_extension)
|
||||
save_mock.assert_called_once_with(ANY, existing_extension)
|
||||
assert response["name"] == payload["name"]
|
||||
assert response["api_key"] == _masked_api_key(payload["api_key"])
|
||||
|
||||
@ -239,6 +239,6 @@ def test_api_based_extension_detail_delete_removes_extension(app: Flask, monkeyp
|
||||
):
|
||||
response, status = APIBasedExtensionDetailAPI().delete(extension_id)
|
||||
|
||||
delete_mock.assert_called_once_with(existing_extension)
|
||||
delete_mock.assert_called_once_with(ANY, existing_extension)
|
||||
assert status == 204
|
||||
assert response == ""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user