This commit is contained in:
hjlarry 2025-08-19 10:57:04 +08:00
parent ad37863183
commit d4c2003450
3 changed files with 29 additions and 24 deletions

View File

@ -4,7 +4,7 @@ import logging
from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from typing import Optional
from typing import Optional, Any
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import func, select
@ -270,7 +270,7 @@ class ProviderConfiguration(BaseModel):
return self._get_specific_provider_credential(credential_id)
# Default behavior: return current active provider credentials
credentials = self.custom_configuration.provider.credentials
credentials = self.custom_configuration.provider.credentials if self.custom_configuration.provider else {}
return self.obfuscated_credentials(
credentials=credentials,
@ -290,7 +290,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
def _validate(s: Session) -> tuple[Provider | None, dict]:
def _validate(s: Session) -> dict:
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.provider_credential_schema.credential_form_schemas
@ -452,7 +452,11 @@ class ProviderConfiguration(BaseModel):
raise
def _update_load_balancing_configs_with_credential(
self, credential_id: str, credential_record: dict, credential_source: str, session: Session
self,
credential_id: str,
credential_record: ProviderCredential | ProviderModelCredential,
credential_source: str,
session: Session,
) -> None:
"""
Update load balancing configurations that reference the given credential_id.
@ -513,13 +517,13 @@ class ProviderConfiguration(BaseModel):
raise ValueError("Credential record not found.")
# Check if this credential is used in load balancing configs
stmt = select(LoadBalancingModelConfig).where(
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider",
)
lb_configs_using_credential = session.execute(stmt).scalars().all()
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
try:
for lb_config in lb_configs_using_credential:
lb_credentials_cache = ProviderCredentialsCache(
@ -541,11 +545,11 @@ class ProviderConfiguration(BaseModel):
# Check available credentials count BEFORE deleting
# if this is the last credential, we need to delete the provider record
stmt = select(func.count(ProviderCredential.id)).where(
count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider,
)
available_credentials_count = session.execute(stmt).scalar()
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
if provider_record and available_credentials_count <= 1:
@ -644,14 +648,12 @@ class ProviderConfiguration(BaseModel):
)
with Session(db.engine) as session:
model_record = self._get_custom_model_record(model_type=model_type, model=model, session=session)
stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == model_record.provider_name,
ProviderModelCredential.model_name == model_record.model_name,
ProviderModelCredential.model_type == model_record.model_type,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
credential = session.execute(stmt).scalar_one_or_none()
@ -722,6 +724,7 @@ class ProviderConfiguration(BaseModel):
if self.provider.model_credential_schema
else [],
)
return None
def validate_custom_model_credentials(
self,
@ -730,7 +733,7 @@ class ProviderConfiguration(BaseModel):
credentials: dict,
credential_id: str = "",
session: Session | None = None,
) -> tuple[ProviderModel | None, dict]:
) -> dict:
"""
Validate custom model credentials.
@ -741,7 +744,7 @@ class ProviderConfiguration(BaseModel):
:return:
"""
def _validate(s: Session) -> tuple[ProviderModel | None, dict]:
def _validate(s: Session) -> dict:
# Get provider credential secret variables
provider_credential_secret_variables = self.extract_secret_variables(
self.provider.model_credential_schema.credential_form_schemas
@ -938,13 +941,13 @@ class ProviderConfiguration(BaseModel):
if not credential_record:
raise ValueError("Credential record not found.")
stmt = select(LoadBalancingModelConfig).where(
lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider,
LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model",
)
lb_configs_using_credential = session.execute(stmt).scalars().all()
lb_configs_using_credential = session.execute(lb_stmt).scalars().all()
try:
for lb_config in lb_configs_using_credential:
@ -966,13 +969,13 @@ class ProviderConfiguration(BaseModel):
# Check available credentials count BEFORE deleting
# if this is the last credential, we need to delete the custom model record
stmt = select(func.count(ProviderModelCredential.id)).where(
count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider,
ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
)
available_credentials_count = session.execute(stmt).scalar()
available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record)
if provider_model_record and available_credentials_count <= 1:
@ -1161,7 +1164,7 @@ class ProviderConfiguration(BaseModel):
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
LoadBalancingModelConfig.model_name == model,
)
load_balancing_config_count = session.execute(stmt).scalar()
load_balancing_config_count = session.execute(stmt).scalar() or 0
if load_balancing_config_count <= 1:
raise ValueError("Model load balancing configuration must be more than 1.")
@ -1264,8 +1267,8 @@ class ProviderConfiguration(BaseModel):
provider_name=self.provider.provider,
preferred_provider_type=provider_type.value,
)
session.add(preferred_model_provider)
session.commit()
s.add(preferred_model_provider)
s.commit()
if session:
return _switch(session)

View File

@ -333,6 +333,7 @@ class ModelLoadBalancingService:
enabled = config.get("enabled")
if credential_id:
credential_record: ProviderCredential | ProviderModelCredential | None = None
if config_from == "predefined-model":
credential_record = (
db.session.query(ProviderCredential)
@ -407,6 +408,7 @@ class ModelLoadBalancingService:
if credential_id:
credential_source = "provider" if config_from == "predefined-model" else "custom_model"
assert credential_record is not None
load_balancing_model_config = LoadBalancingModelConfig(
tenant_id=tenant_id,
provider_name=provider_configuration.provider.provider,

View File

@ -138,7 +138,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_provider_credential(credential_id=credential_id)
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict) -> None:
"""
@ -226,7 +226,7 @@ class ModelProviderService:
:return:
"""
provider_configuration = self._get_provider_configuration(tenant_id, provider)
return provider_configuration.get_custom_model_credential(
return provider_configuration.get_custom_model_credential( # type: ignore
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
)