refactor: model_load_balancing_service and api_tools_manage_service (#34434)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Renzo 2026-04-02 06:38:35 +02:00 committed by GitHub
parent f9d9ad7a38
commit 399d3f8da5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 60 deletions

View File

@ -110,20 +110,21 @@ class ModelLoadBalancingService:
credential_source_type = CredentialSourceType.CUSTOM_MODEL credential_source_type = CredentialSourceType.CUSTOM_MODEL
# Get load balancing configurations # Get load balancing configurations
load_balancing_configs = ( load_balancing_configs = list(
db.session.query(LoadBalancingModelConfig) db.session.scalars(
.where( select(LoadBalancingModelConfig)
LoadBalancingModelConfig.tenant_id == tenant_id, .where(
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.model_type == model_type_enum, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.model_type == model_type_enum,
or_( LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.credential_source_type == credential_source_type, or_(
LoadBalancingModelConfig.credential_source_type.is_(None), LoadBalancingModelConfig.credential_source_type == credential_source_type,
), LoadBalancingModelConfig.credential_source_type.is_(None),
) ),
.order_by(LoadBalancingModelConfig.created_at) )
.all() .order_by(LoadBalancingModelConfig.created_at)
).all()
) )
if provider_configuration.custom_configuration.provider: if provider_configuration.custom_configuration.provider:
@ -143,7 +144,7 @@ class ModelLoadBalancingService:
load_balancing_configs.insert(0, inherit_config) load_balancing_configs.insert(0, inherit_config)
else: else:
# move the inherit configuration to the first # move the inherit configuration to the first
for i, load_balancing_config in enumerate(load_balancing_configs[:]): for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
if load_balancing_config.name == "__inherit__": if load_balancing_config.name == "__inherit__":
inherit_config = load_balancing_configs.pop(i) inherit_config = load_balancing_configs.pop(i)
load_balancing_configs.insert(0, inherit_config) load_balancing_configs.insert(0, inherit_config)
@ -235,8 +236,8 @@ class ModelLoadBalancingService:
model_type_enum = ModelType.value_of(model_type) model_type_enum = ModelType.value_of(model_type)
# Get load balancing configurations # Get load balancing configurations
load_balancing_model_config = ( load_balancing_model_config = db.session.scalar(
db.session.query(LoadBalancingModelConfig) select(LoadBalancingModelConfig)
.where( .where(
LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
@ -244,7 +245,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id, LoadBalancingModelConfig.id == config_id,
) )
.first() .limit(1)
) )
if not load_balancing_model_config: if not load_balancing_model_config:
@ -351,26 +352,26 @@ class ModelLoadBalancingService:
if credential_id: if credential_id:
if config_from == "predefined-model": if config_from == "predefined-model":
credential_record = ( credential_record = db.session.scalar(
db.session.query(ProviderCredential) select(ProviderCredential)
.filter_by( .where(
id=credential_id, ProviderCredential.id == credential_id,
tenant_id=tenant_id, ProviderCredential.tenant_id == tenant_id,
provider_name=provider_configuration.provider.provider, ProviderCredential.provider_name == provider_configuration.provider.provider,
) )
.first() .limit(1)
) )
else: else:
credential_record = ( credential_record = db.session.scalar(
db.session.query(ProviderModelCredential) select(ProviderModelCredential)
.filter_by( .where(
id=credential_id, ProviderModelCredential.id == credential_id,
tenant_id=tenant_id, ProviderModelCredential.tenant_id == tenant_id,
provider_name=provider_configuration.provider.provider, ProviderModelCredential.provider_name == provider_configuration.provider.provider,
model_name=model, ProviderModelCredential.model_name == model,
model_type=model_type_enum, ProviderModelCredential.model_type == model_type_enum,
) )
.first() .limit(1)
) )
if not credential_record: if not credential_record:
raise ValueError(f"Provider credential with id {credential_id} not found") raise ValueError(f"Provider credential with id {credential_id} not found")
@ -510,8 +511,8 @@ class ModelLoadBalancingService:
load_balancing_model_config = None load_balancing_model_config = None
if config_id: if config_id:
# Get load balancing config # Get load balancing config
load_balancing_model_config = ( load_balancing_model_config = db.session.scalar(
db.session.query(LoadBalancingModelConfig) select(LoadBalancingModelConfig)
.where( .where(
LoadBalancingModelConfig.tenant_id == tenant_id, LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider, LoadBalancingModelConfig.provider_name == provider,
@ -519,7 +520,7 @@ class ModelLoadBalancingService:
LoadBalancingModelConfig.model_name == model, LoadBalancingModelConfig.model_name == model,
LoadBalancingModelConfig.id == config_id, LoadBalancingModelConfig.id == config_id,
) )
.first() .limit(1)
) )
if not load_balancing_model_config: if not load_balancing_model_config:

View File

@ -124,13 +124,13 @@ class ApiToolManageService:
provider_name = provider_name.strip() provider_name = provider_name.strip()
# check if the provider exists # check if the provider exists
provider = ( provider = db.session.scalar(
db.session.query(ApiToolProvider) select(ApiToolProvider)
.where( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
.first() .limit(1)
) )
if provider is not None: if provider is not None:
@ -215,13 +215,13 @@ class ApiToolManageService:
""" """
list api tool provider tools list api tool provider tools
""" """
provider: ApiToolProvider | None = ( provider: ApiToolProvider | None = db.session.scalar(
db.session.query(ApiToolProvider) select(ApiToolProvider)
.where( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
.first() .limit(1)
) )
if provider is None: if provider is None:
@ -259,13 +259,13 @@ class ApiToolManageService:
provider_name = provider_name.strip() provider_name = provider_name.strip()
# check if the provider exists # check if the provider exists
provider = ( provider = db.session.scalar(
db.session.query(ApiToolProvider) select(ApiToolProvider)
.where( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider, ApiToolProvider.name == original_provider,
) )
.first() .limit(1)
) )
if provider is None: if provider is None:
@ -328,13 +328,13 @@ class ApiToolManageService:
""" """
delete tool provider delete tool provider
""" """
provider = ( provider = db.session.scalar(
db.session.query(ApiToolProvider) select(ApiToolProvider)
.where( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
.first() .limit(1)
) )
if provider is None: if provider is None:
@ -378,13 +378,13 @@ class ApiToolManageService:
if tool_bundle is None: if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}") raise ValueError(f"invalid tool name {tool_name}")
db_provider = ( db_provider = db.session.scalar(
db.session.query(ApiToolProvider) select(ApiToolProvider)
.where( .where(
ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name, ApiToolProvider.name == provider_name,
) )
.first() .limit(1)
) )
if not db_provider: if not db_provider:

View File

@ -158,7 +158,7 @@ def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_fo
credential_id="cred-1", credential_id="cred-1",
enabled=True, enabled=True,
) )
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config] mock_db.session.scalars.return_value.all.return_value = [config]
mocker.patch( mocker.patch(
"services.model_load_balancing_service.encrypter.get_decrypt_decoding", "services.model_load_balancing_service.encrypter.get_decrypt_decoding",
return_value=("rsa", "cipher"), return_value=("rsa", "cipher"),
@ -216,7 +216,7 @@ def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate
credential_id=None, credential_id=None,
enabled=False, enabled=False,
) )
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ mock_db.session.scalars.return_value.all.return_value = [
normal_config, normal_config,
inherit_config, inherit_config,
] ]
@ -269,7 +269,7 @@ def test_get_load_balancing_config_should_return_none_when_config_not_found(
# Arrange # Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.query.return_value.where.return_value.first.return_value = None mock_db.session.scalar.return_value = None
# Act # Act
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
@ -289,7 +289,7 @@ def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_
} }
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True) config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
mock_db.session.query.return_value.where.return_value.first.return_value = config mock_db.session.scalar.return_value = config
# Act # Act
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
@ -389,7 +389,7 @@ def test_update_load_balancing_configs_should_raise_value_error_when_credential_
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = [] mock_db.session.scalars.return_value.all.return_value = []
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None mock_db.session.scalar.return_value = None
# Act + Assert # Act + Assert
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"): with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
@ -578,7 +578,7 @@ def test_update_load_balancing_configs_should_create_from_existing_provider_cred
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.scalars.return_value.all.return_value = [] mock_db.session.scalars.return_value.all.return_value = []
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}') credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record mock_db.session.scalar.return_value = credential_record
# Act # Act
service.update_load_balancing_configs( service.update_load_balancing_configs(
@ -623,7 +623,7 @@ def test_validate_load_balancing_credentials_should_raise_value_error_when_confi
# Arrange # Arrange
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
mock_db.session.query.return_value.where.return_value.first.return_value = None mock_db.session.scalar.return_value = None
# Act + Assert # Act + Assert
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"): with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
@ -646,7 +646,7 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
existing_config = SimpleNamespace(id="cfg-1") existing_config = SimpleNamespace(id="cfg-1")
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config mock_db.session.scalar.return_value = existing_config
mock_validate = mocker.patch.object(service, "_custom_credentials_validate") mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
# Act # Act