diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index de3b0964ff..111de89178 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -205,16 +205,10 @@ class ProviderConfiguration(BaseModel): """ Get custom provider record. """ - # get provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - stmt = select(Provider).where( Provider.tenant_id == self.tenant_id, Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name.in_(provider_names), + Provider.provider_name.in_(self._get_provider_names()), ) return session.execute(stmt).scalar_one_or_none() @@ -276,7 +270,7 @@ class ProviderConfiguration(BaseModel): """ stmt = select(ProviderCredential.id).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.credential_name == credential_name, ) if exclude_id: @@ -324,7 +318,7 @@ class ProviderConfiguration(BaseModel): try: stmt = select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ProviderCredential.id == credential_id, ) credential_record = s.execute(stmt).scalar_one_or_none() @@ -374,7 +368,7 @@ class ProviderConfiguration(BaseModel): session=session, query_factory=lambda: select(ProviderCredential).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ), ) @@ -387,7 +381,7 @@ class ProviderConfiguration(BaseModel): session=session, query_factory=lambda: select(ProviderModelCredential).where( ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ), @@ -423,6 +417,16 @@ class ProviderConfiguration(BaseModel): logger.warning("Error generating next credential name: %s", str(e)) return "API KEY 1" + def _get_provider_names(self): + """ + The provider name might be stored in the database as either `openai` or `langgenius/openai/openai`. + """ + model_provider_id = ModelProviderID(self.provider.provider) + provider_names = [self.provider.provider] + if model_provider_id.is_langgenius(): + provider_names.append(model_provider_id.provider_name) + return provider_names + def create_provider_credential(self, credentials: dict, credential_name: str | None): """ Add custom provider credentials. @@ -501,7 +505,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) # Get the credential record to update @@ -554,7 +558,7 @@ class ProviderConfiguration(BaseModel): # Find all load balancing configs that use this credential_id stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == credential_source, ) @@ -591,7 +595,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) # Get the credential record to update @@ -602,7 +606,7 @@ class ProviderConfiguration(BaseModel): # Check if this credential is used in load balancing configs lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "provider", ) @@ -624,7 +628,7 @@ class ProviderConfiguration(BaseModel): # if this is the last credential, we need to delete the provider record count_stmt = select(func.count(ProviderCredential.id)).where( ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) available_credentials_count = session.execute(count_stmt).scalar() or 0 session.delete(credential_record) @@ -668,7 +672,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderCredential).where( ProviderCredential.id == credential_id, ProviderCredential.tenant_id == self.tenant_id, - ProviderCredential.provider_name == self.provider.provider, + ProviderCredential.provider_name.in_(self._get_provider_names()), ) credential_record = session.execute(stmt).scalar_one_or_none() if not credential_record: @@ -737,7 +741,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -784,7 +788,7 @@ class ProviderConfiguration(BaseModel): """ stmt = select(ProviderModelCredential).where( ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ProviderModelCredential.credential_name == credential_name, @@ -860,7 +864,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -997,7 +1001,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1042,7 +1046,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1052,7 +1056,7 @@ class ProviderConfiguration(BaseModel): lb_stmt = select(LoadBalancingModelConfig).where( LoadBalancingModelConfig.tenant_id == self.tenant_id, - LoadBalancingModelConfig.provider_name == self.provider.provider, + LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, LoadBalancingModelConfig.credential_source_type == "custom_model", ) @@ -1075,7 +1079,7 @@ class ProviderConfiguration(BaseModel): # if this is the last credential, we need to delete the custom model record count_stmt = select(func.count(ProviderModelCredential.id)).where( ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1115,7 +1119,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1157,7 +1161,7 @@ class ProviderConfiguration(BaseModel): stmt = select(ProviderModelCredential).where( ProviderModelCredential.id == credential_id, ProviderModelCredential.tenant_id == self.tenant_id, - ProviderModelCredential.provider_name == self.provider.provider, + ProviderModelCredential.provider_name.in_(self._get_provider_names()), ProviderModelCredential.model_name == model, ProviderModelCredential.model_type == model_type.to_origin_model_type(), ) @@ -1204,15 +1208,9 @@ class ProviderConfiguration(BaseModel): """ Get provider model setting. """ - - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - stmt = select(ProviderModelSetting).where( ProviderModelSetting.tenant_id == self.tenant_id, - ProviderModelSetting.provider_name.in_(provider_names), + ProviderModelSetting.provider_name.in_(self._get_provider_names()), ProviderModelSetting.model_type == model_type.to_origin_model_type(), ProviderModelSetting.model_name == model, ) @@ -1384,15 +1382,9 @@ class ProviderConfiguration(BaseModel): return def _switch(s: Session): - # get preferred provider - model_provider_id = ModelProviderID(self.provider.provider) - provider_names = [self.provider.provider] - if model_provider_id.is_langgenius(): - provider_names.append(model_provider_id.provider_name) - stmt = select(TenantPreferredModelProvider).where( TenantPreferredModelProvider.tenant_id == self.tenant_id, - TenantPreferredModelProvider.provider_name.in_(provider_names), + TenantPreferredModelProvider.provider_name.in_(self._get_provider_names()), ) preferred_model_provider = s.execute(stmt).scalars().first() diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6f642ab5db..499d39bd5d 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -513,6 +513,21 @@ class ProviderManager: return provider_name_to_provider_load_balancing_model_configs_dict + @staticmethod + def _get_provider_names(provider_name: str) -> list[str]: + """ + provider_name: `openai` or `langgenius/openai/openai` + return: [`openai`, `langgenius/openai/openai`] + """ + provider_names = [provider_name] + model_provider_id = ModelProviderID(provider_name) + if model_provider_id.is_langgenius(): + if "/" in provider_name: + provider_names.append(model_provider_id.provider_name) + else: + provider_names.append(str(model_provider_id)) + return provider_names + @staticmethod def get_provider_available_credentials(tenant_id: str, provider_name: str) -> list[CredentialConfiguration]: """ @@ -525,7 +540,10 @@ class ProviderManager: with Session(db.engine, expire_on_commit=False) as session: stmt = ( select(ProviderCredential) - .where(ProviderCredential.tenant_id == tenant_id, ProviderCredential.provider_name == provider_name) + .where( + ProviderCredential.tenant_id == tenant_id, + ProviderCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)), + ) .order_by(ProviderCredential.created_at.desc()) ) @@ -554,7 +572,7 @@ class ProviderManager: select(ProviderModelCredential) .where( ProviderModelCredential.tenant_id == tenant_id, - ProviderModelCredential.provider_name == provider_name, + ProviderModelCredential.provider_name.in_(ProviderManager._get_provider_names(provider_name)), ProviderModelCredential.model_name == model_name, ProviderModelCredential.model_type == model_type, )