mirror of
https://github.com/langgenius/dify.git
synced 2026-04-30 13:37:24 +08:00
refactor: model_load_balancing_service and api_tools_manage_service (#34434)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f9d9ad7a38
commit
399d3f8da5
@ -110,20 +110,21 @@ class ModelLoadBalancingService:
|
|||||||
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
credential_source_type = CredentialSourceType.CUSTOM_MODEL
|
||||||
|
|
||||||
# Get load balancing configurations
|
# Get load balancing configurations
|
||||||
load_balancing_configs = (
|
load_balancing_configs = list(
|
||||||
db.session.query(LoadBalancingModelConfig)
|
db.session.scalars(
|
||||||
.where(
|
select(LoadBalancingModelConfig)
|
||||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
.where(
|
||||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||||
LoadBalancingModelConfig.model_type == model_type_enum,
|
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||||
LoadBalancingModelConfig.model_name == model,
|
LoadBalancingModelConfig.model_type == model_type_enum,
|
||||||
or_(
|
LoadBalancingModelConfig.model_name == model,
|
||||||
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
or_(
|
||||||
LoadBalancingModelConfig.credential_source_type.is_(None),
|
LoadBalancingModelConfig.credential_source_type == credential_source_type,
|
||||||
),
|
LoadBalancingModelConfig.credential_source_type.is_(None),
|
||||||
)
|
),
|
||||||
.order_by(LoadBalancingModelConfig.created_at)
|
)
|
||||||
.all()
|
.order_by(LoadBalancingModelConfig.created_at)
|
||||||
|
).all()
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider_configuration.custom_configuration.provider:
|
if provider_configuration.custom_configuration.provider:
|
||||||
@ -143,7 +144,7 @@ class ModelLoadBalancingService:
|
|||||||
load_balancing_configs.insert(0, inherit_config)
|
load_balancing_configs.insert(0, inherit_config)
|
||||||
else:
|
else:
|
||||||
# move the inherit configuration to the first
|
# move the inherit configuration to the first
|
||||||
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
|
for i, load_balancing_config in enumerate(load_balancing_configs.copy()):
|
||||||
if load_balancing_config.name == "__inherit__":
|
if load_balancing_config.name == "__inherit__":
|
||||||
inherit_config = load_balancing_configs.pop(i)
|
inherit_config = load_balancing_configs.pop(i)
|
||||||
load_balancing_configs.insert(0, inherit_config)
|
load_balancing_configs.insert(0, inherit_config)
|
||||||
@ -235,8 +236,8 @@ class ModelLoadBalancingService:
|
|||||||
model_type_enum = ModelType.value_of(model_type)
|
model_type_enum = ModelType.value_of(model_type)
|
||||||
|
|
||||||
# Get load balancing configurations
|
# Get load balancing configurations
|
||||||
load_balancing_model_config = (
|
load_balancing_model_config = db.session.scalar(
|
||||||
db.session.query(LoadBalancingModelConfig)
|
select(LoadBalancingModelConfig)
|
||||||
.where(
|
.where(
|
||||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||||
@ -244,7 +245,7 @@ class ModelLoadBalancingService:
|
|||||||
LoadBalancingModelConfig.model_name == model,
|
LoadBalancingModelConfig.model_name == model,
|
||||||
LoadBalancingModelConfig.id == config_id,
|
LoadBalancingModelConfig.id == config_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not load_balancing_model_config:
|
if not load_balancing_model_config:
|
||||||
@ -351,26 +352,26 @@ class ModelLoadBalancingService:
|
|||||||
|
|
||||||
if credential_id:
|
if credential_id:
|
||||||
if config_from == "predefined-model":
|
if config_from == "predefined-model":
|
||||||
credential_record = (
|
credential_record = db.session.scalar(
|
||||||
db.session.query(ProviderCredential)
|
select(ProviderCredential)
|
||||||
.filter_by(
|
.where(
|
||||||
id=credential_id,
|
ProviderCredential.id == credential_id,
|
||||||
tenant_id=tenant_id,
|
ProviderCredential.tenant_id == tenant_id,
|
||||||
provider_name=provider_configuration.provider.provider,
|
ProviderCredential.provider_name == provider_configuration.provider.provider,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
credential_record = (
|
credential_record = db.session.scalar(
|
||||||
db.session.query(ProviderModelCredential)
|
select(ProviderModelCredential)
|
||||||
.filter_by(
|
.where(
|
||||||
id=credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
tenant_id=tenant_id,
|
ProviderModelCredential.tenant_id == tenant_id,
|
||||||
provider_name=provider_configuration.provider.provider,
|
ProviderModelCredential.provider_name == provider_configuration.provider.provider,
|
||||||
model_name=model,
|
ProviderModelCredential.model_name == model,
|
||||||
model_type=model_type_enum,
|
ProviderModelCredential.model_type == model_type_enum,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not credential_record:
|
if not credential_record:
|
||||||
raise ValueError(f"Provider credential with id {credential_id} not found")
|
raise ValueError(f"Provider credential with id {credential_id} not found")
|
||||||
@ -510,8 +511,8 @@ class ModelLoadBalancingService:
|
|||||||
load_balancing_model_config = None
|
load_balancing_model_config = None
|
||||||
if config_id:
|
if config_id:
|
||||||
# Get load balancing config
|
# Get load balancing config
|
||||||
load_balancing_model_config = (
|
load_balancing_model_config = db.session.scalar(
|
||||||
db.session.query(LoadBalancingModelConfig)
|
select(LoadBalancingModelConfig)
|
||||||
.where(
|
.where(
|
||||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name == provider,
|
LoadBalancingModelConfig.provider_name == provider,
|
||||||
@ -519,7 +520,7 @@ class ModelLoadBalancingService:
|
|||||||
LoadBalancingModelConfig.model_name == model,
|
LoadBalancingModelConfig.model_name == model,
|
||||||
LoadBalancingModelConfig.id == config_id,
|
LoadBalancingModelConfig.id == config_id,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not load_balancing_model_config:
|
if not load_balancing_model_config:
|
||||||
|
|||||||
@ -124,13 +124,13 @@ class ApiToolManageService:
|
|||||||
provider_name = provider_name.strip()
|
provider_name = provider_name.strip()
|
||||||
|
|
||||||
# check if the provider exists
|
# check if the provider exists
|
||||||
provider = (
|
provider = db.session.scalar(
|
||||||
db.session.query(ApiToolProvider)
|
select(ApiToolProvider)
|
||||||
.where(
|
.where(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.name == provider_name,
|
ApiToolProvider.name == provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider is not None:
|
if provider is not None:
|
||||||
@ -215,13 +215,13 @@ class ApiToolManageService:
|
|||||||
"""
|
"""
|
||||||
list api tool provider tools
|
list api tool provider tools
|
||||||
"""
|
"""
|
||||||
provider: ApiToolProvider | None = (
|
provider: ApiToolProvider | None = db.session.scalar(
|
||||||
db.session.query(ApiToolProvider)
|
select(ApiToolProvider)
|
||||||
.where(
|
.where(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.name == provider_name,
|
ApiToolProvider.name == provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider is None:
|
if provider is None:
|
||||||
@ -259,13 +259,13 @@ class ApiToolManageService:
|
|||||||
provider_name = provider_name.strip()
|
provider_name = provider_name.strip()
|
||||||
|
|
||||||
# check if the provider exists
|
# check if the provider exists
|
||||||
provider = (
|
provider = db.session.scalar(
|
||||||
db.session.query(ApiToolProvider)
|
select(ApiToolProvider)
|
||||||
.where(
|
.where(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.name == original_provider,
|
ApiToolProvider.name == original_provider,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider is None:
|
if provider is None:
|
||||||
@ -328,13 +328,13 @@ class ApiToolManageService:
|
|||||||
"""
|
"""
|
||||||
delete tool provider
|
delete tool provider
|
||||||
"""
|
"""
|
||||||
provider = (
|
provider = db.session.scalar(
|
||||||
db.session.query(ApiToolProvider)
|
select(ApiToolProvider)
|
||||||
.where(
|
.where(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.name == provider_name,
|
ApiToolProvider.name == provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if provider is None:
|
if provider is None:
|
||||||
@ -378,13 +378,13 @@ class ApiToolManageService:
|
|||||||
if tool_bundle is None:
|
if tool_bundle is None:
|
||||||
raise ValueError(f"invalid tool name {tool_name}")
|
raise ValueError(f"invalid tool name {tool_name}")
|
||||||
|
|
||||||
db_provider = (
|
db_provider = db.session.scalar(
|
||||||
db.session.query(ApiToolProvider)
|
select(ApiToolProvider)
|
||||||
.where(
|
.where(
|
||||||
ApiToolProvider.tenant_id == tenant_id,
|
ApiToolProvider.tenant_id == tenant_id,
|
||||||
ApiToolProvider.name == provider_name,
|
ApiToolProvider.name == provider_name,
|
||||||
)
|
)
|
||||||
.first()
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not db_provider:
|
if not db_provider:
|
||||||
|
|||||||
@ -158,7 +158,7 @@ def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_fo
|
|||||||
credential_id="cred-1",
|
credential_id="cred-1",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config]
|
mock_db.session.scalars.return_value.all.return_value = [config]
|
||||||
mocker.patch(
|
mocker.patch(
|
||||||
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
"services.model_load_balancing_service.encrypter.get_decrypt_decoding",
|
||||||
return_value=("rsa", "cipher"),
|
return_value=("rsa", "cipher"),
|
||||||
@ -216,7 +216,7 @@ def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate
|
|||||||
credential_id=None,
|
credential_id=None,
|
||||||
enabled=False,
|
enabled=False,
|
||||||
)
|
)
|
||||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
|
mock_db.session.scalars.return_value.all.return_value = [
|
||||||
normal_config,
|
normal_config,
|
||||||
inherit_config,
|
inherit_config,
|
||||||
]
|
]
|
||||||
@ -269,7 +269,7 @@ def test_get_load_balancing_config_should_return_none_when_config_not_found(
|
|||||||
# Arrange
|
# Arrange
|
||||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||||
@ -289,7 +289,7 @@ def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_
|
|||||||
}
|
}
|
||||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||||
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
|
config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True)
|
||||||
mock_db.session.query.return_value.where.return_value.first.return_value = config
|
mock_db.session.scalar.return_value = config
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1")
|
||||||
@ -389,7 +389,7 @@ def test_update_load_balancing_configs_should_raise_value_error_when_credential_
|
|||||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||||
mock_db.session.scalars.return_value.all.return_value = []
|
mock_db.session.scalars.return_value.all.return_value = []
|
||||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
# Act + Assert
|
# Act + Assert
|
||||||
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
|
with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"):
|
||||||
@ -578,7 +578,7 @@ def test_update_load_balancing_configs_should_create_from_existing_provider_cred
|
|||||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||||
mock_db.session.scalars.return_value.all.return_value = []
|
mock_db.session.scalars.return_value.all.return_value = []
|
||||||
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
|
credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}')
|
||||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record
|
mock_db.session.scalar.return_value = credential_record
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
service.update_load_balancing_configs(
|
service.update_load_balancing_configs(
|
||||||
@ -623,7 +623,7 @@ def test_validate_load_balancing_credentials_should_raise_value_error_when_confi
|
|||||||
# Arrange
|
# Arrange
|
||||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
mock_db.session.scalar.return_value = None
|
||||||
|
|
||||||
# Act + Assert
|
# Act + Assert
|
||||||
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
|
with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"):
|
||||||
@ -646,7 +646,7 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_
|
|||||||
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema())
|
||||||
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
service.provider_manager.get_configurations.return_value = {"openai": provider_configuration}
|
||||||
existing_config = SimpleNamespace(id="cfg-1")
|
existing_config = SimpleNamespace(id="cfg-1")
|
||||||
mock_db.session.query.return_value.where.return_value.first.return_value = existing_config
|
mock_db.session.scalar.return_value = existing_config
|
||||||
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
|
mock_validate = mocker.patch.object(service, "_custom_credentials_validate")
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user