refactor(api): type load balancing config dicts with TypedDict (#34639)

This commit is contained in:
Statxc 2026-04-07 02:58:10 -03:00 committed by GitHub
parent 19c80f0f0e
commit 63db9a7a2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 7 deletions

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Union
from typing import Any, TypedDict, Union
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.entities.provider_entities import (
@ -25,6 +25,23 @@ from models.provider import LoadBalancingModelConfig, ProviderCredential, Provid
logger = logging.getLogger(__name__)
class LoadBalancingConfigDetailDict(TypedDict):
    """Typed shape of a single load balancing configuration.

    This is the dictionary returned by
    ``ModelLoadBalancingService.get_load_balancing_config`` when a
    configuration is found. Every key is required.
    """

    # Identifier of the load balancing model config.
    id: str
    # Name of the load balancing model config.
    name: str
    # Credential mapping for this config.
    # NOTE(review): the service appears to post-process these values against
    # the provider's credential form schemas before returning them -- confirm
    # whether secret fields are masked at that point.
    credentials: dict[str, Any]
    # Whether this config is enabled.
    enabled: bool
class LoadBalancingConfigSummaryDict(TypedDict):
id: str
name: str
credentials: dict[str, Any]
credential_id: str | None
enabled: bool
in_cooldown: bool
ttl: int
class ModelLoadBalancingService:
@staticmethod
def _get_provider_manager(tenant_id: str) -> ProviderManager:
@ -74,7 +91,7 @@ class ModelLoadBalancingService:
def get_load_balancing_configs(
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
) -> tuple[bool, list[dict]]:
) -> tuple[bool, list[LoadBalancingConfigSummaryDict]]:
"""
Get load balancing configurations.
:param tenant_id: workspace id
@ -156,7 +173,7 @@ class ModelLoadBalancingService:
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
# fetch status and ttl for each config
datas = []
datas: list[LoadBalancingConfigSummaryDict] = []
for load_balancing_config in load_balancing_configs:
in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
tenant_id=tenant_id,
@ -214,7 +231,7 @@ class ModelLoadBalancingService:
def get_load_balancing_config(
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
) -> dict | None:
) -> LoadBalancingConfigDetailDict | None:
"""
Get load balancing configuration.
:param tenant_id: workspace id
@ -267,12 +284,13 @@ class ModelLoadBalancingService:
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
)
return {
result: LoadBalancingConfigDetailDict = {
"id": load_balancing_model_config.id,
"name": load_balancing_model_config.name,
"credentials": credentials,
"enabled": load_balancing_model_config.enabled,
}
return result
def _init_inherit_config(
self, tenant_id: str, provider: str, model: str, model_type: ModelType

View File

@ -635,7 +635,7 @@ class WorkflowService:
# If we can't determine the status, assume load balancing is not enabled
return False
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict[str, Any]]:
"""
Get all load balancing configurations for a model.
@ -659,7 +659,7 @@ class WorkflowService:
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
)
all_configs = configs + custom_configs
all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)
return [config for config in all_configs if config.get("credential_id")]