fix: Ensure compatibility with old provider name when updating model credentials (#26017)

This commit is contained in:
非法操作 2025-09-22 19:39:17 +08:00 committed by GitHub
parent 24e8d21b3f
commit ef80d3b707
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 42 deletions

View File

@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
""" """
Get custom provider record. Get custom provider record.
""" """
# get provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(Provider).where( stmt = select(Provider).where(
Provider.tenant_id == self.tenant_id, Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value, Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names), Provider.provider_name.in_(self._get_provider_names()),
) )
return session.execute(stmt).scalar_one_or_none() return session.execute(stmt).scalar_one_or_none()
@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
""" """
stmt = select(ProviderCredential.id).where( stmt = select(ProviderCredential.id).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.credential_name == credential_name, ProviderCredential.credential_name == credential_name,
) )
if exclude_id: if exclude_id:
@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
try: try:
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
) )
credential_record = s.execute(stmt).scalar_one_or_none() credential_record = s.execute(stmt).scalar_one_or_none()
@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
session=session, session=session,
query_factory=lambda: select(ProviderCredential).where( query_factory=lambda: select(ProviderCredential).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
), ),
) )
@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
session=session, session=session,
query_factory=lambda: select(ProviderModelCredential).where( query_factory=lambda: select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
), ),
@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
logger.warning("Error generating next credential name: %s", str(e)) logger.warning("Error generating next credential name: %s", str(e))
return "API KEY 1" return "API KEY 1"
def _get_provider_names(self):
"""
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
"""
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
return provider_names
def create_provider_credential(self, credentials: dict, credential_name: str | None): def create_provider_credential(self, credentials: dict, credential_name: str | None):
""" """
Add custom provider credentials. Add custom provider credentials.
@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
# Get the credential record to update # Get the credential record to update
@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
# Find all load balancing configs that use this credential_id # Find all load balancing configs that use this credential_id
stmt = select(LoadBalancingModelConfig).where( stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == credential_source, LoadBalancingModelConfig.credential_source_type == credential_source,
) )
@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
# Get the credential record to update # Get the credential record to update
@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
# Check if this credential is used in load balancing configs # Check if this credential is used in load balancing configs
lb_stmt = select(LoadBalancingModelConfig).where( lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "provider", LoadBalancingModelConfig.credential_source_type == "provider",
) )
@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the provider record # if this is the last credential, we need to delete the provider record
count_stmt = select(func.count(ProviderCredential.id)).where( count_stmt = select(func.count(ProviderCredential.id)).where(
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
available_credentials_count = session.execute(count_stmt).scalar() or 0 available_credentials_count = session.execute(count_stmt).scalar() or 0
session.delete(credential_record) session.delete(credential_record)
@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderCredential).where( stmt = select(ProviderCredential).where(
ProviderCredential.id == credential_id, ProviderCredential.id == credential_id,
ProviderCredential.tenant_id == self.tenant_id, ProviderCredential.tenant_id == self.tenant_id,
ProviderCredential.provider_name == self.provider.provider, ProviderCredential.provider_name.in_(self._get_provider_names()),
) )
credential_record = session.execute(stmt).scalar_one_or_none() credential_record = session.execute(stmt).scalar_one_or_none()
if not credential_record: if not credential_record:
@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
""" """
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
ProviderModelCredential.credential_name == credential_name, ProviderModelCredential.credential_name == credential_name,
@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
lb_stmt = select(LoadBalancingModelConfig).where( lb_stmt = select(LoadBalancingModelConfig).where(
LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name == self.provider.provider, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_id == credential_id,
LoadBalancingModelConfig.credential_source_type == "custom_model", LoadBalancingModelConfig.credential_source_type == "custom_model",
) )
@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
# if this is the last credential, we need to delete the custom model record # if this is the last credential, we need to delete the custom model record
count_stmt = select(func.count(ProviderModelCredential.id)).where( count_stmt = select(func.count(ProviderModelCredential.id)).where(
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
stmt = select(ProviderModelCredential).where( stmt = select(ProviderModelCredential).where(
ProviderModelCredential.id == credential_id, ProviderModelCredential.id == credential_id,
ProviderModelCredential.tenant_id == self.tenant_id, ProviderModelCredential.tenant_id == self.tenant_id,
ProviderModelCredential.provider_name == self.provider.provider, ProviderModelCredential.provider_name.in_(self._get_provider_names()),
ProviderModelCredential.model_name == model, ProviderModelCredential.model_name == model,
ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.model_type == model_type.to_origin_model_type(),
) )
@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
""" """
Get provider model setting. Get provider model setting.
""" """
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(ProviderModelSetting).where( stmt = select(ProviderModelSetting).where(
ProviderModelSetting.tenant_id == self.tenant_id, ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names), ProviderModelSetting.provider_name.in_(self._get_provider_names()),
ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_type == model_type.to_origin_model_type(),
ProviderModelSetting.model_name == model, ProviderModelSetting.model_name == model,
) )
@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
return return
def _switch(s: Session): def _switch(s: Session):
# get preferred provider
model_provider_id = ModelProviderID(self.provider.provider)
provider_names = [self.provider.provider]
if model_provider_id.is_langgenius():
provider_names.append(model_provider_id.provider_name)
stmt = select(TenantPreferredModelProvider).where( stmt = select(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == self.tenant_id, TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names), TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
) )
preferred_model_provider = s.execute(stmt).scalars().first() preferred_model_provider = s.execute(stmt).scalars().first()

View File

@ -513,6 +513,21 @@ class ProviderManager:
return provider_name_to_provider_load_balancing_model_configs_dict return provider_name_to_provider_load_balancing_model_configs_dict
@staticmethod
def _get_provider_names(provider_name: str) -> list[str]:
"""
provider_name: `openai` or `langgenius/openai/openai`
return: [`openai`, `langgenius/openai/openai`]
"""
provider_names = [provider_name]
model_provider_id = ModelProviderID(provider_name)
if model_provider_id.is_langgenius():
if "/" in provider_name:
provider_names.append(model_provider_id.provider_name)
else:
provider_names.append(str(model_provider_id))
return provider_names
@staticmethod @staticmethod
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
""" """
@ -525,7 +540,10 @@ class ProviderManager:
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
stmt = ( stmt = (
select(ProviderCredential) select(ProviderCredential)
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) .where(
ProviderCredential.tenant_id == tenant_id,
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
)
.order_by(ProviderCredential.created_at.desc()) .order_by(ProviderCredential.created_at.desc())
) )
@ -554,7 +572,7 @@ class ProviderManager:
select(ProviderModelCredential) select(ProviderModelCredential)
.where( .where(
ProviderModelCredential.tenant_id == tenant_id, ProviderModelCredential.tenant_id == tenant_id,
ProviderModelCredential.provider_name == provider_name, ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
ProviderModelCredential.model_name == model_name, ProviderModelCredential.model_name == model_name,
ProviderModelCredential.model_type == model_type, ProviderModelCredential.model_type == model_type,
) )