mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor(api): type load balancing config dicts with TypedDict (#34639)
This commit is contained in:
parent
19c80f0f0e
commit
63db9a7a2f
@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
from typing import Any, TypedDict, Union
|
||||
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.entities.provider_entities import (
|
||||
@ -25,6 +25,23 @@ from models.provider import LoadBalancingModelConfig, ProviderCredential, Provid
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoadBalancingConfigDetailDict(TypedDict):
    """Shape of the single-config payload returned by ``get_load_balancing_config``.

    All four keys are required (the TypedDict is total).
    """

    # Taken from the load balancing model config record's ``id``.
    id: str
    # Taken from the load balancing model config record's ``name``.
    name: str
    # Credential values keyed by field name; the value types are not constrained here.
    credentials: dict[str, Any]
    # Taken from the load balancing model config record's ``enabled`` flag.
    enabled: bool
|
||||
|
||||
|
||||
class LoadBalancingConfigSummaryDict(TypedDict):
|
||||
id: str
|
||||
name: str
|
||||
credentials: dict[str, Any]
|
||||
credential_id: str | None
|
||||
enabled: bool
|
||||
in_cooldown: bool
|
||||
ttl: int
|
||||
|
||||
|
||||
class ModelLoadBalancingService:
|
||||
@staticmethod
|
||||
def _get_provider_manager(tenant_id: str) -> ProviderManager:
|
||||
@ -74,7 +91,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_from: str = ""
|
||||
) -> tuple[bool, list[dict]]:
|
||||
) -> tuple[bool, list[LoadBalancingConfigSummaryDict]]:
|
||||
"""
|
||||
Get load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -156,7 +173,7 @@ class ModelLoadBalancingService:
|
||||
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||
|
||||
# fetch status and ttl for each config
|
||||
datas = []
|
||||
datas: list[LoadBalancingConfigSummaryDict] = []
|
||||
for load_balancing_config in load_balancing_configs:
|
||||
in_cooldown, ttl = LBModelManager.get_config_in_cooldown_and_ttl(
|
||||
tenant_id=tenant_id,
|
||||
@ -214,7 +231,7 @@ class ModelLoadBalancingService:
|
||||
|
||||
def get_load_balancing_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
|
||||
) -> dict | None:
|
||||
) -> LoadBalancingConfigDetailDict | None:
|
||||
"""
|
||||
Get load balancing configuration.
|
||||
:param tenant_id: workspace id
|
||||
@ -267,12 +284,13 @@ class ModelLoadBalancingService:
|
||||
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
)
|
||||
|
||||
return {
|
||||
result: LoadBalancingConfigDetailDict = {
|
||||
"id": load_balancing_model_config.id,
|
||||
"name": load_balancing_model_config.name,
|
||||
"credentials": credentials,
|
||||
"enabled": load_balancing_model_config.enabled,
|
||||
}
|
||||
return result
|
||||
|
||||
def _init_inherit_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: ModelType
|
||||
|
||||
@ -635,7 +635,7 @@ class WorkflowService:
|
||||
# If we can't determine the status, assume load balancing is not enabled
|
||||
return False
|
||||
|
||||
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict]:
|
||||
def _get_load_balancing_configs(self, tenant_id: str, provider: str, model_name: str) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all load balancing configurations for a model.
|
||||
|
||||
@ -659,7 +659,7 @@ class WorkflowService:
|
||||
_, custom_configs = model_load_balancing_service.get_load_balancing_configs(
|
||||
tenant_id=tenant_id, provider=provider, model=model_name, model_type="llm", config_from="custom-model"
|
||||
)
|
||||
all_configs = configs + custom_configs
|
||||
all_configs = cast(list[dict[str, Any]], configs) + cast(list[dict[str, Any]], custom_configs)
|
||||
|
||||
return [config for config in all_configs if config.get("credential_id")]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user