mirror of https://github.com/langgenius/dify.git
fix: Ensure compatibility with old provider name when updating model credentials (#26017)
This commit is contained in:
parent
24e8d21b3f
commit
ef80d3b707
|
|
@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
Get custom provider record.
|
Get custom provider record.
|
||||||
"""
|
"""
|
||||||
# get provider
|
|
||||||
model_provider_id = ModelProviderID(self.provider.provider)
|
|
||||||
provider_names = [self.provider.provider]
|
|
||||||
if model_provider_id.is_langgenius():
|
|
||||||
provider_names.append(model_provider_id.provider_name)
|
|
||||||
|
|
||||||
stmt = select(Provider).where(
|
stmt = select(Provider).where(
|
||||||
Provider.tenant_id == self.tenant_id,
|
Provider.tenant_id == self.tenant_id,
|
||||||
Provider.provider_type == ProviderType.CUSTOM.value,
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||||
Provider.provider_name.in_(provider_names),
|
Provider.provider_name.in_(self._get_provider_names()),
|
||||||
)
|
)
|
||||||
|
|
||||||
return session.execute(stmt).scalar_one_or_none()
|
return session.execute(stmt).scalar_one_or_none()
|
||||||
|
|
@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
stmt = select(ProviderCredential.id).where(
|
stmt = select(ProviderCredential.id).where(
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderCredential.credential_name == credential_name,
|
ProviderCredential.credential_name == credential_name,
|
||||||
)
|
)
|
||||||
if exclude_id:
|
if exclude_id:
|
||||||
|
|
@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
try:
|
try:
|
||||||
stmt = select(ProviderCredential).where(
|
stmt = select(ProviderCredential).where(
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderCredential.id == credential_id,
|
ProviderCredential.id == credential_id,
|
||||||
)
|
)
|
||||||
credential_record = s.execute(stmt).scalar_one_or_none()
|
credential_record = s.execute(stmt).scalar_one_or_none()
|
||||||
|
|
@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
session=session,
|
session=session,
|
||||||
query_factory=lambda: select(ProviderCredential).where(
|
query_factory=lambda: select(ProviderCredential).where(
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
session=session,
|
session=session,
|
||||||
query_factory=lambda: select(ProviderModelCredential).where(
|
query_factory=lambda: select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
),
|
),
|
||||||
|
|
@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel):
|
||||||
logger.warning("Error generating next credential name: %s", str(e))
|
logger.warning("Error generating next credential name: %s", str(e))
|
||||||
return "API KEY 1"
|
return "API KEY 1"
|
||||||
|
|
||||||
|
def _get_provider_names(self):
|
||||||
|
"""
|
||||||
|
The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`.
|
||||||
|
"""
|
||||||
|
model_provider_id = ModelProviderID(self.provider.provider)
|
||||||
|
provider_names = [self.provider.provider]
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
return provider_names
|
||||||
|
|
||||||
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
def create_provider_credential(self, credentials: dict, credential_name: str | None):
|
||||||
"""
|
"""
|
||||||
Add custom provider credentials.
|
Add custom provider credentials.
|
||||||
|
|
@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderCredential).where(
|
stmt = select(ProviderCredential).where(
|
||||||
ProviderCredential.id == credential_id,
|
ProviderCredential.id == credential_id,
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the credential record to update
|
# Get the credential record to update
|
||||||
|
|
@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
# Find all load balancing configs that use this credential_id
|
# Find all load balancing configs that use this credential_id
|
||||||
stmt = select(LoadBalancingModelConfig).where(
|
stmt = select(LoadBalancingModelConfig).where(
|
||||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||||
LoadBalancingModelConfig.credential_id == credential_id,
|
LoadBalancingModelConfig.credential_id == credential_id,
|
||||||
LoadBalancingModelConfig.credential_source_type == credential_source,
|
LoadBalancingModelConfig.credential_source_type == credential_source,
|
||||||
)
|
)
|
||||||
|
|
@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderCredential).where(
|
stmt = select(ProviderCredential).where(
|
||||||
ProviderCredential.id == credential_id,
|
ProviderCredential.id == credential_id,
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the credential record to update
|
# Get the credential record to update
|
||||||
|
|
@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
# Check if this credential is used in load balancing configs
|
# Check if this credential is used in load balancing configs
|
||||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||||
LoadBalancingModelConfig.credential_id == credential_id,
|
LoadBalancingModelConfig.credential_id == credential_id,
|
||||||
LoadBalancingModelConfig.credential_source_type == "provider",
|
LoadBalancingModelConfig.credential_source_type == "provider",
|
||||||
)
|
)
|
||||||
|
|
@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
# if this is the last credential, we need to delete the provider record
|
# if this is the last credential, we need to delete the provider record
|
||||||
count_stmt = select(func.count(ProviderCredential.id)).where(
|
count_stmt = select(func.count(ProviderCredential.id)).where(
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
)
|
)
|
||||||
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
available_credentials_count = session.execute(count_stmt).scalar() or 0
|
||||||
session.delete(credential_record)
|
session.delete(credential_record)
|
||||||
|
|
@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderCredential).where(
|
stmt = select(ProviderCredential).where(
|
||||||
ProviderCredential.id == credential_id,
|
ProviderCredential.id == credential_id,
|
||||||
ProviderCredential.tenant_id == self.tenant_id,
|
ProviderCredential.tenant_id == self.tenant_id,
|
||||||
ProviderCredential.provider_name == self.provider.provider,
|
ProviderCredential.provider_name.in_(self._get_provider_names()),
|
||||||
)
|
)
|
||||||
credential_record = session.execute(stmt).scalar_one_or_none()
|
credential_record = session.execute(stmt).scalar_one_or_none()
|
||||||
if not credential_record:
|
if not credential_record:
|
||||||
|
|
@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.id == credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
ProviderModelCredential.credential_name == credential_name,
|
ProviderModelCredential.credential_name == credential_name,
|
||||||
|
|
@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.id == credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.id == credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.id == credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
|
|
||||||
lb_stmt = select(LoadBalancingModelConfig).where(
|
lb_stmt = select(LoadBalancingModelConfig).where(
|
||||||
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
LoadBalancingModelConfig.tenant_id == self.tenant_id,
|
||||||
LoadBalancingModelConfig.provider_name == self.provider.provider,
|
LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()),
|
||||||
LoadBalancingModelConfig.credential_id == credential_id,
|
LoadBalancingModelConfig.credential_id == credential_id,
|
||||||
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
LoadBalancingModelConfig.credential_source_type == "custom_model",
|
||||||
)
|
)
|
||||||
|
|
@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
# if this is the last credential, we need to delete the custom model record
|
# if this is the last credential, we need to delete the custom model record
|
||||||
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
count_stmt = select(func.count(ProviderModelCredential.id)).where(
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.id == credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel):
|
||||||
stmt = select(ProviderModelCredential).where(
|
stmt = select(ProviderModelCredential).where(
|
||||||
ProviderModelCredential.id == credential_id,
|
ProviderModelCredential.id == credential_id,
|
||||||
ProviderModelCredential.tenant_id == self.tenant_id,
|
ProviderModelCredential.tenant_id == self.tenant_id,
|
||||||
ProviderModelCredential.provider_name == self.provider.provider,
|
ProviderModelCredential.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelCredential.model_name == model,
|
ProviderModelCredential.model_name == model,
|
||||||
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
ProviderModelCredential.model_type == model_type.to_origin_model_type(),
|
||||||
)
|
)
|
||||||
|
|
@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel):
|
||||||
"""
|
"""
|
||||||
Get provider model setting.
|
Get provider model setting.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_provider_id = ModelProviderID(self.provider.provider)
|
|
||||||
provider_names = [self.provider.provider]
|
|
||||||
if model_provider_id.is_langgenius():
|
|
||||||
provider_names.append(model_provider_id.provider_name)
|
|
||||||
|
|
||||||
stmt = select(ProviderModelSetting).where(
|
stmt = select(ProviderModelSetting).where(
|
||||||
ProviderModelSetting.tenant_id == self.tenant_id,
|
ProviderModelSetting.tenant_id == self.tenant_id,
|
||||||
ProviderModelSetting.provider_name.in_(provider_names),
|
ProviderModelSetting.provider_name.in_(self._get_provider_names()),
|
||||||
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
|
||||||
ProviderModelSetting.model_name == model,
|
ProviderModelSetting.model_name == model,
|
||||||
)
|
)
|
||||||
|
|
@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel):
|
||||||
return
|
return
|
||||||
|
|
||||||
def _switch(s: Session):
|
def _switch(s: Session):
|
||||||
# get preferred provider
|
|
||||||
model_provider_id = ModelProviderID(self.provider.provider)
|
|
||||||
provider_names = [self.provider.provider]
|
|
||||||
if model_provider_id.is_langgenius():
|
|
||||||
provider_names.append(model_provider_id.provider_name)
|
|
||||||
|
|
||||||
stmt = select(TenantPreferredModelProvider).where(
|
stmt = select(TenantPreferredModelProvider).where(
|
||||||
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
TenantPreferredModelProvider.tenant_id == self.tenant_id,
|
||||||
TenantPreferredModelProvider.provider_name.in_(provider_names),
|
TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()),
|
||||||
)
|
)
|
||||||
preferred_model_provider = s.execute(stmt).scalars().first()
|
preferred_model_provider = s.execute(stmt).scalars().first()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -513,6 +513,21 @@ class ProviderManager:
|
||||||
|
|
||||||
return provider_name_to_provider_load_balancing_model_configs_dict
|
return provider_name_to_provider_load_balancing_model_configs_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_provider_names(provider_name: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
provider_name: `openai` or `langgenius/openai/openai`
|
||||||
|
return: [`openai`, `langgenius/openai/openai`]
|
||||||
|
"""
|
||||||
|
provider_names = [provider_name]
|
||||||
|
model_provider_id = ModelProviderID(provider_name)
|
||||||
|
if model_provider_id.is_langgenius():
|
||||||
|
if "/" in provider_name:
|
||||||
|
provider_names.append(model_provider_id.provider_name)
|
||||||
|
else:
|
||||||
|
provider_names.append(str(model_provider_id))
|
||||||
|
return provider_names
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -525,7 +540,10 @@ class ProviderManager:
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
select(ProviderCredential)
|
select(ProviderCredential)
|
||||||
.where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name)
|
.where(
|
||||||
|
ProviderCredential.tenant_id == tenant_id,
|
||||||
|
ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||||
|
)
|
||||||
.order_by(ProviderCredential.created_at.desc())
|
.order_by(ProviderCredential.created_at.desc())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -554,7 +572,7 @@ class ProviderManager:
|
||||||
select(ProviderModelCredential)
|
select(ProviderModelCredential)
|
||||||
.where(
|
.where(
|
||||||
ProviderModelCredential.tenant_id == tenant_id,
|
ProviderModelCredential.tenant_id == tenant_id,
|
||||||
ProviderModelCredential.provider_name == provider_name,
|
ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)),
|
||||||
ProviderModelCredential.model_name == model_name,
|
ProviderModelCredential.model_name == model_name,
|
||||||
ProviderModelCredential.model_type == model_type,
|
ProviderModelCredential.model_type == model_type,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue